diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 6572e421b65b..d4afaf51e6e8 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -86,7 +86,7 @@
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin, KVConnectorOutput)
+    KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from .utils import (AttentionGroup, MultiModalBudget,
@@ -196,6 +196,14 @@ def __init__(
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
 
+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping micro-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        self.broadcast_pp_output = (
+            self.parallel_config.distributed_executor_backend
+            == "external_launcher" and len(get_pp_group().ranks) > 0)
+
         # Model-related.
         self.num_query_heads = model_config.get_num_attention_heads(
             parallel_config)
@@ -1701,7 +1709,6 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1732,7 +1739,6 @@ def _pool(
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            kv_connector_output=kv_connector_output,
         )
 
     def _preprocess(
@@ -2073,39 +2079,47 @@ def execute_model(
 
         with record_function_or_nullcontext("Postprocess"):
             if self.use_aux_hidden_state_outputs:
+                # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
             else:
+                # Common case.
                 hidden_states = model_output
                 aux_hidden_states = None
 
-            # Broadcast PP output for external_launcher (torchrun)
-            # to make sure we are synced across pp ranks
-            # TODO: Support overlapping mirco-batches
-            # https://github.com/vllm-project/vllm/issues/18019
-            broadcast_pp_output = \
-                self.parallel_config.distributed_executor_backend \
-                == "external_launcher" and len(get_pp_group().ranks) > 0
-            if not get_pp_group().is_last_rank:
-                # For mid-pipeline stages, return the hidden states.
-                assert isinstance(hidden_states, IntermediateTensors)
-                if not broadcast_pp_output:
+            if not self.broadcast_pp_output:
+                # Common case.
+                if not get_pp_group().is_last_rank:
+                    # Return the intermediate tensors.
+                    assert isinstance(hidden_states, IntermediateTensors)
                     hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
-                get_pp_group().send_tensor_dict(
-                    hidden_states.tensors, all_gather_group=get_tp_group())
-                logits = None
-            else:
+
                 if self.is_pooling_model:
-                    return self._pool(hidden_states, num_scheduled_tokens,
-                                      num_scheduled_tokens_np,
-                                      kv_connector_output)
+                    # Return the pooling output.
+                    output = self._pool(hidden_states, num_scheduled_tokens,
+                                        num_scheduled_tokens_np)
+                    output.kv_connector_output = kv_connector_output
+                    return output
 
                 sample_hidden_states = hidden_states[logits_indices]
                 logits = self.model.compute_logits(sample_hidden_states, None)
-            if broadcast_pp_output:
-                model_output_broadcast_data = {
-                    "logits": logits.contiguous(),
-                } if logits is not None else {}
+            else:
+                # Rare case.
+                assert not self.is_pooling_model
+
+                if not get_pp_group().is_last_rank:
+                    get_pp_group().send_tensor_dict(
+                        hidden_states.tensors, all_gather_group=get_tp_group())
+                    logits = None
+                else:
+                    sample_hidden_states = hidden_states[logits_indices]
+                    logits = self.model.compute_logits(sample_hidden_states,
+                                                       None)
+
+                model_output_broadcast_data = {}
+                if logits is not None:
+                    model_output_broadcast_data["logits"] = logits.contiguous()
+
                 model_output_broadcast_data = get_pp_group(
                 ).broadcast_tensor_dict(model_output_broadcast_data,
                                         src=len(get_pp_group().ranks) - 1)