Commit 3e903b6

[Chore] Minor simplification for non-PP path (#24810)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 973c9d0 commit 3e903b6

File tree: 1 file changed (+39, -25 lines)

vllm/v1/worker/gpu_model_runner.py (39 additions, 25 deletions)
@@ -86,7 +86,7 @@
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin, KVConnectorOutput)
+    KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from .utils import (AttentionGroup, MultiModalBudget,
@@ -196,6 +196,14 @@ def __init__(
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs

+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping mirco-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        self.broadcast_pp_output = (
+            self.parallel_config.distributed_executor_backend
+            == "external_launcher" and len(get_pp_group().ranks) > 0)
+
         # Model-related.
         self.num_query_heads = model_config.get_num_attention_heads(
             parallel_config)
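
This hunk precomputes the external_launcher check once in __init__ rather than rebuilding it on every call to execute_model. A minimal, standalone sketch of that pattern, where Runner, backend, and pp_world_size are illustrative stand-ins and not vLLM's actual config objects:

class Runner:
    def __init__(self, backend: str, pp_world_size: int):
        # Evaluated once per worker: does this rank need to broadcast
        # pipeline-parallel output to stay in sync with the other PP ranks?
        self.broadcast_pp_output = (
            backend == "external_launcher" and pp_world_size > 0)

    def execute_model(self, hidden_states):
        if not self.broadcast_pp_output:
            # Common case: no extra synchronization needed.
            return hidden_states
        # Rare case (torchrun / external_launcher): broadcast before returning.
        return self._broadcast(hidden_states)

    def _broadcast(self, hidden_states):
        # Placeholder for the collective; see the execute_model hunk below.
        return hidden_states
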
@@ -1701,7 +1709,6 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1732,7 +1739,6 @@ def _pool(
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            kv_connector_output=kv_connector_output,
         )

     def _preprocess(
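
With kv_connector_output removed from _pool's signature, the caller attaches it to the returned ModelRunnerOutput instead, as the execute_model hunk below shows. A condensed before/after of the call site (not a complete excerpt of the method):

# Before: threaded through as a parameter.
#     return self._pool(hidden_states, num_scheduled_tokens,
#                       num_scheduled_tokens_np, kv_connector_output)
# After: set on the returned output by the caller.
output = self._pool(hidden_states, num_scheduled_tokens,
                    num_scheduled_tokens_np)
output.kv_connector_output = kv_connector_output
return output
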
@@ -2073,39 +2079,47 @@ def execute_model(

         with record_function_or_nullcontext("Postprocess"):
             if self.use_aux_hidden_state_outputs:
+                # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
             else:
+                # Common case.
                 hidden_states = model_output
                 aux_hidden_states = None

-            # Broadcast PP output for external_launcher (torchrun)
-            # to make sure we are synced across pp ranks
-            # TODO: Support overlapping mirco-batches
-            # https://github.com/vllm-project/vllm/issues/18019
-            broadcast_pp_output = \
-                self.parallel_config.distributed_executor_backend \
-                == "external_launcher" and len(get_pp_group().ranks) > 0
-            if not get_pp_group().is_last_rank:
-                # For mid-pipeline stages, return the hidden states.
-                assert isinstance(hidden_states, IntermediateTensors)
-                if not broadcast_pp_output:
+            if not self.broadcast_pp_output:
+                # Common case.
+                if not get_pp_group().is_last_rank:
+                    # Return the intermediate tensors.
+                    assert isinstance(hidden_states, IntermediateTensors)
                     hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
-                get_pp_group().send_tensor_dict(
-                    hidden_states.tensors, all_gather_group=get_tp_group())
-                logits = None
-            else:
+
                 if self.is_pooling_model:
-                    return self._pool(hidden_states, num_scheduled_tokens,
-                                      num_scheduled_tokens_np,
-                                      kv_connector_output)
+                    # Return the pooling output.
+                    output = self._pool(hidden_states, num_scheduled_tokens,
+                                        num_scheduled_tokens_np)
+                    output.kv_connector_output = kv_connector_output
+                    return output

                 sample_hidden_states = hidden_states[logits_indices]
                 logits = self.model.compute_logits(sample_hidden_states, None)
-            if broadcast_pp_output:
-                model_output_broadcast_data = {
-                    "logits": logits.contiguous(),
-                } if logits is not None else {}
+            else:
+                # Rare case.
+                assert not self.is_pooling_model
+
+                if not get_pp_group().is_last_rank:
+                    get_pp_group().send_tensor_dict(
+                        hidden_states.tensors, all_gather_group=get_tp_group())
+                    logits = None
+                else:
+                    sample_hidden_states = hidden_states[logits_indices]
+                    logits = self.model.compute_logits(sample_hidden_states,
+                                                       None)
+
+                model_output_broadcast_data = {}
+                if logits is not None:
+                    model_output_broadcast_data["logits"] = logits.contiguous()
+
                 model_output_broadcast_data = get_pp_group(
                 ).broadcast_tensor_dict(model_output_broadcast_data,
                                         src=len(get_pp_group().ranks) - 1)
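
Net effect of the execute_model hunk: the branch structure is inverted so that the common (non-external_launcher) path is the outer if branch, while the torchrun PP-broadcast logic is isolated under else. A self-contained sketch of the resulting control flow, using plain booleans and strings as stand-ins for vLLM's runtime objects (illustrative only, not an excerpt of the real method):

def postprocess(broadcast_pp_output: bool, is_last_rank: bool,
                is_pooling_model: bool) -> str:
    if not broadcast_pp_output:
        # Common case.
        if not is_last_rank:
            return "intermediate_tensors"  # handed to the next PP rank
        if is_pooling_model:
            return "pooling_output"
        return "sampler_output"  # logits computed locally on the last rank
    # Rare case: external_launcher (torchrun). Only the last rank computes
    # logits; they are then broadcast so every PP rank returns the same output.
    return "broadcast_sampler_output"

# Quick check of the common single-rank sampling path.
assert postprocess(False, True, False) == "sampler_output"
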
