
Conversation

elvircrn
Contributor

@elvircrn elvircrn commented Sep 1, 2025

Purpose

The purpose of this PR is to replace the Triton SiLU implementation with a faster CUDA version.

Here's a benchmark slice:

(benchmark plot: silu_benchmark_total_tokens_10)

This was achieved by launching additional CUDA blocks and parallelizing over the T dimension. This means that we end up launching no-op threads and that the parallelization factor is now an additional tunable parameter.
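To make the scheme concrete, here is a minimal sketch; it is not the kernel from this PR: all names are hypothetical, the body is simplified to SiLU-and-multiply, and the fp8 quantization is omitted. It only illustrates how a parallelization factor over T maps onto extra CUDA blocks, with blocks that start past an expert's valid token count returning immediately as no-ops.

```cuda
// Sketch only: hypothetical names, fp8 quantization omitted.
// Grid is (E, parallelization factor); each block handles one chunk of T
// for one expert, and blocks past the expert's valid token count are no-ops.
__global__ void silu_mul_sketch(const float* __restrict__ y,   // (E, T, 2H)
                                float* __restrict__ out,       // (E, T, H)
                                const int* __restrict__ tokens_per_expert,
                                int T, int H) {
  int e = blockIdx.x;                            // expert index
  int n_tokens = tokens_per_expert[e];           // valid tokens for this expert
  int chunk = (T + gridDim.y - 1) / gridDim.y;   // gridDim.y = parallelization factor
  int t_begin = blockIdx.y * chunk;
  int t_end = min(t_begin + chunk, n_tokens);
  if (t_begin >= n_tokens) return;               // no-op block

  for (int t = t_begin; t < t_end; ++t) {
    const float* row = y + ((size_t)e * T + t) * 2 * H;
    float* out_row = out + ((size_t)e * T + t) * H;
    for (int h = threadIdx.x; h < H; h += blockDim.x) {
      float gate = row[h];
      float up = row[H + h];
      out_row[h] = (gate / (1.0f + expf(-gate))) * up;  // SiLU(gate) * up
    }
  }
}
```

Launching with, for example, dim3 grid(E, 16) would give each expert 16 blocks along T, matching the default factor discussed below.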

To understand the impact of the parallelization factor, see the following graphs for E <= 9:

(benchmark plots for four token-count distributions: experts[0] = T; experts = T; experts = torch.randint(0, T, size=(E,)) with sort(experts); experts = torch.randint(0, T, size=(E,)))

For E=32, we have:
(the same four benchmark plots, repeated for E = 32)

A parallelization factor of 16x works well for most configurations, so it was chosen as the default.

Test Plan

Given y of shape (E, T, 2H) as input, the function is expected to work for GROUP_SIZE=128 and any H divisible by 128.
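As a rough illustration of what these constraints mean for the host-side launcher, here is a hypothetical validation snippet in the torch C++ extension style; the function name and messages are assumptions, not the actual code in this PR.

```cuda
#include <torch/extension.h>

// Hypothetical input validation for y of shape (E, T, 2H); GROUP_SIZE is the
// per-group quantization width, assumed to be 128.
constexpr int64_t GROUP_SIZE = 128;

void check_silu_mul_fp8_quant_input(const at::Tensor& y) {
  TORCH_CHECK(y.dim() == 3, "expected y of shape (E, T, 2H)");
  TORCH_CHECK(y.size(2) % 2 == 0, "last dim of y must be 2*H");
  const int64_t H = y.size(2) / 2;
  TORCH_CHECK(H % GROUP_SIZE == 0, "H must be divisible by ", GROUP_SIZE);
}
```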

Test Result

VLLM_ALL2ALL_BACKEND="deepep_low_latency" VLLM_USE_DEEP_GEMM=1 g2 lm_eval --model vllm --model_args pretrained=Qwen/Qwen3-30B-A3B-Instruct-2507-FP8,data_parallel_size=2,enable_expert_parallel=True --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
vllm (pretrained=Qwen/Qwen3-30B-A3B-Instruct-2507-FP8,data_parallel_size=2,enable_expert_parallel=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8848|±  |0.0088|
|     |       |strict-match    |     5|exact_match|↑  |0.8711|±  |0.0092|

From main:

vllm (pretrained=Qwen/Qwen3-30B-A3B-Instruct-2507-FP8,data_parallel_size=2,enable_expert_parallel=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8848|±  |0.0088|
|     |       |strict-match    |     5|exact_match|↑  |0.8711|±  |0.0092|

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the performance (Performance-related issues) label Sep 1, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new CUDA kernel for silu_mul_fp8_quant to replace the existing Triton implementation, aiming for better performance. The changes include the new CUDA kernel, its C++ bindings, and updates to tests and benchmarks to compare against the old implementation, which is preserved as a baseline. While the overall approach is sound, the review identified several critical correctness issues in the new CUDA kernel related to parameter handling, as well as high-severity maintainability problems such as dead code, code duplication, and style violations. These issues should be addressed to ensure the correctness and long-term health of the codebase.


// quant params
float fp8_min, float fp8_max) {
static constexpr float EPS = 1e-10;
Contributor


critical

The kernel uses a hardcoded EPS value, ignoring the eps parameter passed to the host function silu_mul_fp8_quant_deep_gemm_cuda. This is a correctness bug. The eps value should be passed to this kernel as an argument and used instead of the hardcoded constant.

This will require changes in:

  1. The kernel signature to accept eps.
  2. The kernel body to use the eps argument (e.g., float y_max = eps; on line 265).
  3. The host launcher silu_mul_fp8_quant_deep_gemm_cuda to pass eps to the kernel (see the sketch after this list).
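A minimal sketch of that change, with abbreviated and partly hypothetical signatures (the real kernel and launcher take more parameters than shown here):

```cuda
// Sketch only: thread eps from the host launcher into the kernel instead of
// relying on a hardcoded constant.
__global__ void silu_mul_fp8_quant_kernel(/* existing params ..., */
                                          float fp8_min, float fp8_max,
                                          float eps) {
  // ...
  float y_max = eps;  // previously: a hardcoded static constexpr float EPS
  // ...
}

void silu_mul_fp8_quant_deep_gemm_cuda(/* ..., */ double eps, double fp8_min,
                                       double fp8_max) {
  // Forward the Python-level eps at launch time, e.g.:
  // silu_mul_fp8_quant_kernel<<<grid, block>>>(
  //     ..., static_cast<float>(fp8_min), static_cast<float>(fp8_max),
  //     static_cast<float>(eps));
}
```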

Contributor Author


@LucasWilkinson Is this OK?

Contributor Author


Maybe we should remove eps/fp8_max/fp8_min from the Python API?

@elvircrn elvircrn marked this pull request as draft September 1, 2025 16:05
@robertgshaw2-redhat robertgshaw2-redhat changed the title from "Squashed cuda silu changes" to "[Kernels][DP/EP] Optimize Silu Kernel for R1" Sep 2, 2025
@robertgshaw2-redhat
Collaborator

  • I got the following:
(APIServer pid=1) (EngineCore_0 pid=283) RuntimeError: Worker failed with error ''Keyword argument NUM_WARPS was specified but unrecognised'', please check the stack trace above for the root cause
  • The root cause is higher up in the stack trace:
(APIServer pid=1) (EngineCore_7 pid=304) (VllmWorker pid=368) ERROR 09-03 01:11:27 [multiproc_executor.py:611]     a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
(APIServer pid=1) (EngineCore_7 pid=304) (VllmWorker pid=368) ERROR 09-03 01:11:27 [multiproc_executor.py:611]                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(APIServer pid=1) (EngineCore_7 pid=304) (VllmWorker pid=368) ERROR 09-03 01:11:27 [multiproc_executor.py:611]   File "/opt/vllm-source/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py", line 206, in silu_mul_fp8_quant_deep_gemm
(APIServer pid=1) (EngineCore_7 pid=304) (VllmWorker pid=368) ERROR 09-03 01:11:27 [multiproc_executor.py:611]     _silu_mul_fp8_quant_deep_gemm[grid](
(APIServer pid=1) (EngineCore_7 pid=304) (VllmWorker pid=368) ERROR 09-03 01:11:27 [multiproc_executor.py:611]   File "/opt/vllm/lib64/python3.12/site-packages/triton/runtime/jit.py", line 390, in <lambda>
(APIServer pid=1) (EngineCore_7 pid=304) (VllmWorker pid=368) ERROR 09-03 01:11:27 [multiproc_executor.py:611]     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
(APIServer pid=1) (EngineCore_7 pid=304) (VllmWorker pid=368) ERROR 09-03 01:11:27 [multiproc_executor.py:611]                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

It looks like the Triton kernel is still hooked up to the actual layer.

@robertgshaw2-redhat
Collaborator

robertgshaw2-redhat commented Sep 3, 2025

Made a branch on the nm repo that hooks up the kernel: 376129e

I still have an IMA:

RuntimeError: CUDA driver error: an illegal memory access was encountered

I'm trying to build it locally.

@elvircrn elvircrn force-pushed the cuda_silu_cherry_picked branch 2 times, most recently from 386a373 to f7820e5 on September 3, 2025 11:28
@elvircrn elvircrn marked this pull request as ready for review September 3, 2025 12:40
@elvircrn elvircrn force-pushed the cuda_silu_cherry_picked branch from 1a43afc to 7d9ad60 on September 3, 2025 12:43
@LucasWilkinson
Collaborator

Made a branch on the nm repo that hooks up the kernel: 376129e

I still have an IMA:

RuntimeError: CUDA driver error: an illegal memory access was encountered

I'm trying to build it locally.

@elvircrn would a coredump help? https://blog.vllm.ai/2025/08/11/cuda-debugging.html @robertgshaw2-redhat would you be able to enable these environment variables?

@elvircrn
Contributor Author

elvircrn commented Sep 3, 2025

@LucasWilkinson Yes, a coredump would help, thanks!

@elvircrn
Contributor Author

elvircrn commented Sep 3, 2025

@robertgshaw2-redhat @LucasWilkinson Can H ever not be divisible by 128?

@elvircrn elvircrn force-pushed the cuda_silu_cherry_picked branch from b9f6303 to 49aab19 on September 3, 2025 16:52
@elvircrn
Contributor Author

elvircrn commented Sep 5, 2025

/gemini-review

Contributor

I've received your request for a new review. Please note that this is a command for the system to initiate a new review, and I will not be providing a conversational response here. The review process will begin shortly.

@tlrmchlsmth tlrmchlsmth added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 10, 2025
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment


LGTM, great job and very thorough work!

@LucasWilkinson
Collaborator

@elvircrn the CI build failures seem related to this PR

@elvircrn
Contributor Author

elvircrn commented Sep 11, 2025

@LucasWilkinson pre-commit check failure unrelated to this PR.

@elvircrn elvircrn closed this Sep 11, 2025
@elvircrn elvircrn reopened this Sep 11, 2025
@tlrmchlsmth tlrmchlsmth enabled auto-merge (squash) September 11, 2025 21:38
@elvircrn
Contributor Author

@LucasWilkinson @tlrmchlsmth

  1. The failing Blackwell test seems to be resolved by [CI Failure] Fix test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe #24750, but it also fails on main.
  2. Neuron is soft-failing with bash: .buildkite/scripts/hardware_ci/run-neuron-test.sh: No such file or directory, also on main.
  3. TPU tests are also failing on main.

Used this for reference a5b84f1

@tlrmchlsmth
Collaborator

@LucasWilkinson @tlrmchlsmth

  1. The failing Blackwell test seems to be resolved by [CI Failure] Fix test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe #24750, but it also fails on main.
  2. Neuron is soft-failing with bash: .buildkite/scripts/hardware_ci/run-neuron-test.sh: No such file or directory, also on main.
  3. TPU tests are also failing on main.

Used this for reference a5b84f1

sounds good - let's wait for distributed-tests-2-gpus and entrypoints-integration-test-api-server to finish and then we can request a force merge

@elvircrn
Contributor Author

@tlrmchlsmth @LucasWilkinson CI is done.

@vllm-bot vllm-bot merged commit 98229db into vllm-project:main Sep 13, 2025
72 of 74 checks passed
BoyuanFeng pushed a commit to BoyuanFeng/vllm that referenced this pull request Sep 14, 2025
845473182 pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
dsxsteven pushed a commit to dsxsteven/vllm_splitPR that referenced this pull request Sep 15, 2025
bbartels pushed a commit to bbartels/vllm that referenced this pull request Sep 15, 2025
cboss6 pushed a commit to cboss6/vllm that referenced this pull request Sep 16, 2025
cboss6 pushed a commit to cboss6/vllm that referenced this pull request Sep 16, 2025