Skip to content

Conversation

SageMoore
Copy link
Contributor

@SageMoore SageMoore commented Aug 26, 2025

Purpose

This PR adds support for Dual-Batch Overlap in VLLM. In it's current state it will only be abled when a user provides the --enable-microbatching flag. Furthermore, it will only be used when all DP groups are running full-decode batches. This PR supports running DBO with full cudagraphs, which is essential for minimizing the CPU overhead and getting performance from this feature.

To implement Dual-Batch Overlap (DBO), at a high level, we split the batch into two microbatches. Then using two threads and two cuda streams, one for communication and one for computation, to overlap the dispatch and combine all-to-all kernels of one microbatch with the compute kernels of the other microbatch.

When microbatching is enabled and supported, the GPUModelRunner will split the batch into two token_slices. These token_slices are then passed into the attention meta data builders during _prepare_inputs to generate one attention metadata object per-microbatch. When actually running the model, the model runner will spawn off two microbatching threads that will each communicate with each other using a UBatchContext. Each of these threads will then run self.model with the appropriate attention meta data.

Without any additional modifications to the code, this will just result in one microbatch running to completion before the other microbatch starts. In order to get overlaps, we've added a "yield" call that can be inserted into the all-to-all kernels to interleave the two microbatches. The yield_and_switch_from_compute_to_comm function yield the CPU from this thread (thread A) to the other microbatching thread (thread B). Once thread A has resumed execution, either because thread B yielded the CPU or finished it's execution, it will swap over to the communication stream and start dispatching kernels there. yield_and_switch_from_comm_to_compute behaves similarly but in the opposite direction. It swaps from the communication stream to the compute stream.

There are both GPU and CPU events to synchronize all of this. That being said, it is absolutely critical that only one microbatching thread is running at a time, meaning the other one is waiting on an event. It is also absolutely critical that both microbatches are running the exact same number of yields.

Test Plan

In general my test plan was to run lm_eval with deepseek-ai/DeepSeek-V2-Lite. We've also run numerous times with R1 in a multi node setup and verified that lm_eval produces reasonable output.

Non-DBO Runs

Eager

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel --enforce-eager

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3567|±  |0.0277|
|     |       |strict-match    |     5|exact_match|↑  |0.3533|±  |0.0276|

Default

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency g2 vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3700|±  |0.0279|
|     |       |strict-match    |     5|exact_match|↑  |0.3667|±  |0.0279|

DBO Runs

Eager

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency g2 vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel --enforce-eager --enable-microbatching --microbatching-token-threshold 4

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3800|±  |0.0281|
|     |       |strict-match    |     5|exact_match|↑  |0.3767|±  |0.0280|

Full cudagraphs

Command

VLLM_ALL2ALL_BACKEND=deepep_low_latency g2 vllm serve --model="deepseek-ai/DeepSeek-V2-Lite" --data-parallel-size 2 --enable-expert-parallel --compilation_config '{"cudagraph_mode": "full_decode_only"}' --enable-microbatching --microbatching-token-threshold 4

Result
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.3733|±  |0.0280|
|     |       |strict-match    |     5|exact_match|↑  |0.3700|±  |0.0279|

LucasWilkinson and others added 30 commits May 22, 2025 20:51
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
Comment on lines +698 to +700
parallel_group.add_argument(
"--dbo-decode-token-threshold",
**parallel_kwargs["dbo_decode_token_threshold"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the future plan for this argument? Will we add a separate --dbo-prefill-token-threshold? Could there be one argument instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep we are planning to add a prefill version of this argument.

Comment on lines +531 to +533
fused_out_buffer = SharedResizableBuffer()
workspace13_buffer = SharedResizableBuffer()
workspace2_buffer = SharedResizableBuffer()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a general memory footprint reduction. Primarily targeting cudagraphs, though.

Comment on lines 230 to 234
# if we are using mrope
if positions.ndim == 2:
sliced_positions = positions[:, tokens_slice]
else:
sliced_positions = positions[tokens_slice]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the mrope interaction? Could we add a comment explaining it here?

Copy link
Contributor Author

@SageMoore SageMoore Sep 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's largely just that mrope adds an additional dimension to the positions tensor so we need to slice the lower dimension. I'll add a comment.

Comment on lines 73 to 78
# Sanity Check that the existing padding isn't giving us an empty second
# ubatch. Abort if so
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
logger.debug("Aborting ubatching %s %s", num_tokens_unpadded,
num_tokens_padded)
should_ubatch = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this something that's expected to happen sometimes, and that's OK? If not, I think this should be a warning instead.

And then could you add a bit more detail to the log, e.g.

Suggested change
# Sanity Check that the existing padding isn't giving us an empty second
# ubatch. Abort if so
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
logger.debug("Aborting ubatching %s %s", num_tokens_unpadded,
num_tokens_padded)
should_ubatch = False
# Sanity Check that the existing padding isn't giving us an empty second
# ubatch. Abort if so
if is_second_ubatch_empty(num_tokens_unpadded, num_tokens_padded):
logger.warning("Empty second µbatch detected: unpadded tokens: %s, padded tokens: %s",
num_tokens_unpadded,
num_tokens_padded)
should_ubatch = False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is expected to happen and isn't necessarily a bug when it does. I find the debug log to be really helpful when debugging misc padding issues. We can certainly take it out, though.

Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: Sage Moore <[email protected]>
Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few minor comments. Overall I think this is ready to land otherwise. Maybe a little rough around the edges with the model runner changes but will be great to have this landed on main, especially as we have a prefill DBO PR ready to be reviewed as soon as this one lands.

Signed-off-by: Sage Moore <[email protected]>
Signed-off-by: Sage Moore <[email protected]>
Copy link

mergify bot commented Sep 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SageMoore.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 15, 2025
@mergify mergify bot removed the needs-rebase label Sep 15, 2025
Copy link

mergify bot commented Sep 16, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @SageMoore.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 16, 2025
@tlrmchlsmth
Copy link
Collaborator

tlrmchlsmth commented Sep 16, 2025

I thought the kernels-moe-test failures were due to VLLM_USE_PRECOMPILED=1 not picking up the changes from #24054, but that was from 3 days ago

Could it be a real problem? @elvircrn, @dougbtv

AttributeError:` '_OpNamespace' '_C' object has no attribute 'silu_mul_fp8_quant_deep_gemm_cuda

Edit: Confirmed it's picking up old binaries.

@mergify mergify bot removed the needs-rebase label Sep 16, 2025
@tlrmchlsmth tlrmchlsmth merged commit 5679399 into vllm-project:main Sep 16, 2025
54 checks passed
@NihalPotdar
Copy link

Hey! Quick question - do you have any performance numbers for this change?

Mainly wondering about the efficiency of the communication-computation overlap strategy in the PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation ready ONLY add when PR is ready to merge/full CI is needed speculative-decoding v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.