vllm/v1/worker/gpu_model_runner.py (64 changes: 39 additions, 25 deletions)
@@ -86,7 +86,7 @@
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin, KVConnectorOutput)
+    KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from .utils import (AttentionGroup, MultiModalBudget,
@@ -196,6 +196,14 @@ def __init__(
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
 
+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping micro-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        self.broadcast_pp_output = (
+            self.parallel_config.distributed_executor_backend
+            == "external_launcher" and len(get_pp_group().ranks) > 0)
+
         # Model-related.
         self.num_query_heads = model_config.get_num_attention_heads(
             parallel_config)
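Hoisting this flag into `__init__` works because both of its inputs, the configured executor backend and the PP group size, are fixed once the engine is constructed, so there is no need to recompute it on every forward pass. A minimal standalone sketch of the pattern (the `Runner` class and its parameters are hypothetical, not vLLM's API):

```python
# Minimal sketch: hoist an invariant predicate out of the per-step hot path.
class Runner:
    def __init__(self, backend: str, pp_world_size: int):
        # Both inputs are fixed at construction time, so compute once.
        self.broadcast_pp_output = (
            backend == "external_launcher" and pp_world_size > 0)

    def execute_model(self) -> str:
        # The hot path only reads the precomputed flag.
        return "rare path" if self.broadcast_pp_output else "common path"


runner = Runner(backend="external_launcher", pp_world_size=2)
assert runner.execute_model() == "rare path"
```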
@@ -1701,7 +1709,6 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
Expand Down Expand Up @@ -1732,7 +1739,6 @@ def _pool(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
kv_connector_output=kv_connector_output,
)

def _preprocess(
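With `kv_connector_output` gone from its signature, `_pool` no longer needs to know about KV connectors; the caller attaches the field to the returned `ModelRunnerOutput` instead. A hedged sketch of this caller-attaches pattern with simplified stand-in types (not vLLM's actual classes):

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class PoolOutput:
    pooler_output: list = field(default_factory=list)
    # Attached by the caller after construction rather than passed in.
    kv_connector_output: Optional[Any] = None


def pool(hidden_states: list) -> PoolOutput:
    # The builder stays agnostic of KV-connector concerns.
    return PoolOutput(pooler_output=[sum(hidden_states)])


output = pool([1.0, 2.0, 3.0])
# Caller-side: attach the cross-cutting field, mirroring
# `output.kv_connector_output = kv_connector_output` in the diff.
output.kv_connector_output = {"finished": True}
```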
@@ -2073,39 +2079,47 @@ def execute_model(
 
         with record_function_or_nullcontext("Postprocess"):
             if self.use_aux_hidden_state_outputs:
+                # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
             else:
+                # Common case.
                 hidden_states = model_output
                 aux_hidden_states = None
 
-            # Broadcast PP output for external_launcher (torchrun)
-            # to make sure we are synced across pp ranks
-            # TODO: Support overlapping mirco-batches
-            # https://github.com/vllm-project/vllm/issues/18019
-            broadcast_pp_output = \
-                self.parallel_config.distributed_executor_backend \
-                == "external_launcher" and len(get_pp_group().ranks) > 0
-            if not get_pp_group().is_last_rank:
-                # For mid-pipeline stages, return the hidden states.
-                assert isinstance(hidden_states, IntermediateTensors)
-                if not broadcast_pp_output:
+            if not self.broadcast_pp_output:
+                # Common case.
+                if not get_pp_group().is_last_rank:
+                    # Return the intermediate tensors.
+                    assert isinstance(hidden_states, IntermediateTensors)
                     hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
-                get_pp_group().send_tensor_dict(
-                    hidden_states.tensors, all_gather_group=get_tp_group())
-                logits = None
-            else:
+
                 if self.is_pooling_model:
-                    return self._pool(hidden_states, num_scheduled_tokens,
-                                      num_scheduled_tokens_np,
-                                      kv_connector_output)
+                    # Return the pooling output.
+                    output = self._pool(hidden_states, num_scheduled_tokens,
+                                        num_scheduled_tokens_np)
+                    output.kv_connector_output = kv_connector_output
+                    return output
 
                 sample_hidden_states = hidden_states[logits_indices]
                 logits = self.model.compute_logits(sample_hidden_states, None)
-            if broadcast_pp_output:
-                model_output_broadcast_data = {
-                    "logits": logits.contiguous(),
-                } if logits is not None else {}
+            else:
@njhill (Member) commented on Sep 13, 2025:

    just a thought, could we put the else logic in a different function to be even less intrusive to the common path?

(A sketch of this suggested factoring follows the diff below.)
+                # Rare case.
+                assert not self.is_pooling_model
+
+                if not get_pp_group().is_last_rank:
+                    get_pp_group().send_tensor_dict(
+                        hidden_states.tensors, all_gather_group=get_tp_group())
+                    logits = None
+                else:
+                    sample_hidden_states = hidden_states[logits_indices]
+                    logits = self.model.compute_logits(sample_hidden_states,
+                                                       None)
+
+                model_output_broadcast_data = {}
+                if logits is not None:
+                    model_output_broadcast_data["logits"] = logits.contiguous()
+
                 model_output_broadcast_data = get_pp_group(
                 ).broadcast_tensor_dict(model_output_broadcast_data,
                                         src=len(get_pp_group().ranks) - 1)
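For reference, @njhill's suggestion could look roughly like the following: the rare external_launcher branch moves into a helper so `execute_model` keeps only the common path. This is a sketch of the review suggestion, not code from this PR; the helper name and signature are hypothetical, and it assumes the module's existing imports (`get_pp_group`, `get_tp_group`).

```python
# Hypothetical helper; the name and exact placement are illustrative only.
def _run_broadcast_pp_output_path(self, hidden_states, logits_indices):
    # Rare case: external_launcher backend with pipeline parallelism.
    assert not self.is_pooling_model

    if not get_pp_group().is_last_rank:
        # Mid-pipeline ranks forward their intermediate tensors.
        get_pp_group().send_tensor_dict(
            hidden_states.tensors, all_gather_group=get_tp_group())
        logits = None
    else:
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

    broadcast_data = {}
    if logits is not None:
        broadcast_data["logits"] = logits.contiguous()

    # Every PP rank receives the last rank's logits dict, keeping
    # torchrun-launched ranks in sync.
    return get_pp_group().broadcast_tensor_dict(
        broadcast_data, src=len(get_pp_group().ranks) - 1)
```

With such a helper, the `else:` branch in `execute_model` would reduce to a single call, keeping the common path free of external_launcher details.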