@@ -86,7 +86,7 @@
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.kv_connector_model_runner_mixin import (
-    KVConnectorModelRunnerMixin, KVConnectorOutput)
+    KVConnectorModelRunnerMixin)
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from .utils import (AttentionGroup, MultiModalBudget,
@@ -196,6 +196,14 @@ def __init__(
         self.max_num_tokens = scheduler_config.max_num_batched_tokens
         self.max_num_reqs = scheduler_config.max_num_seqs
 
+        # Broadcast PP output for external_launcher (torchrun)
+        # to make sure we are synced across pp ranks
+        # TODO: Support overlapping micro-batches
+        # https://github.com/vllm-project/vllm/issues/18019
+        self.broadcast_pp_output = (
+            self.parallel_config.distributed_executor_backend
+            == "external_launcher" and len(get_pp_group().ranks) > 0)
+
         # Model-related.
         self.num_query_heads = model_config.get_num_attention_heads(
             parallel_config)
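The `external_launcher` check depends only on startup configuration, so this hunk hoists it out of `execute_model` and into `__init__`, where it is computed once and later read as `self.broadcast_pp_output` on the hot path. A minimal standalone sketch of that pattern; `ParallelCfg`, `pp_world_size`, `Runner`, and `step` are made-up stand-ins, not vLLM's real config objects:

```python
from dataclasses import dataclass


@dataclass
class ParallelCfg:
    distributed_executor_backend: str
    pp_world_size: int


class Runner:
    def __init__(self, cfg: ParallelCfg):
        # Backend and PP world size are fixed at startup, so the flag is
        # computed once here instead of on every step.
        self.broadcast_pp_output = (
            cfg.distributed_executor_backend == "external_launcher"
            and cfg.pp_world_size > 0)

    def step(self) -> str:
        # Hot path only reads the precomputed attribute.
        return "broadcast" if self.broadcast_pp_output else "common"


print(Runner(ParallelCfg("external_launcher", 2)).step())  # -> broadcast
print(Runner(ParallelCfg("mp", 2)).step())                 # -> common
```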
@@ -1701,7 +1709,6 @@ def _pool(
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        kv_connector_output: Optional[KVConnectorOutput],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1732,7 +1739,6 @@ def _pool(
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            kv_connector_output=kv_connector_output,
         )
 
     def _preprocess(
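Together with the signature change above, this hunk removes `kv_connector_output` from `_pool` entirely; the caller now attaches it to the returned `ModelRunnerOutput` (see the `execute_model` hunk below). A minimal sketch of that caller-attaches pattern, using hypothetical `Output`, `pool`, and `execute` stand-ins rather than vLLM's classes:

```python
from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class Output:
    pooler_output: list = field(default_factory=list)
    kv_connector_output: Optional[Any] = None  # filled in by the caller


def pool(hidden: list) -> Output:
    # pool() no longer needs to know anything about the KV connector.
    return Output(pooler_output=[sum(hidden)])


def execute(hidden: list, kv_connector_output: Any) -> Output:
    out = pool(hidden)
    out.kv_connector_output = kv_connector_output  # attached at the call site
    return out


print(execute([1.0, 2.0], kv_connector_output={"finished": True}))
```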
@@ -2073,39 +2079,47 @@ def execute_model(
 
         with record_function_or_nullcontext("Postprocess"):
             if self.use_aux_hidden_state_outputs:
+                # True when EAGLE 3 is used.
                 hidden_states, aux_hidden_states = model_output
             else:
+                # Common case.
                 hidden_states = model_output
                 aux_hidden_states = None
 
-            # Broadcast PP output for external_launcher (torchrun)
-            # to make sure we are synced across pp ranks
-            # TODO: Support overlapping mirco-batches
-            # https://github.com/vllm-project/vllm/issues/18019
-            broadcast_pp_output = \
-                self.parallel_config.distributed_executor_backend \
-                == "external_launcher" and len(get_pp_group().ranks) > 0
-            if not get_pp_group().is_last_rank:
-                # For mid-pipeline stages, return the hidden states.
-                assert isinstance(hidden_states, IntermediateTensors)
-                if not broadcast_pp_output:
+            if not self.broadcast_pp_output:
+                # Common case.
+                if not get_pp_group().is_last_rank:
+                    # Return the intermediate tensors.
+                    assert isinstance(hidden_states, IntermediateTensors)
                     hidden_states.kv_connector_output = kv_connector_output
                     return hidden_states
-                get_pp_group().send_tensor_dict(
-                    hidden_states.tensors, all_gather_group=get_tp_group())
-                logits = None
-            else:
+
             if self.is_pooling_model:
-                return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np,
-                                  kv_connector_output)
+                # Return the pooling output.
+                output = self._pool(hidden_states, num_scheduled_tokens,
+                                    num_scheduled_tokens_np)
+                output.kv_connector_output = kv_connector_output
+                return output
 
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
-            if broadcast_pp_output:
-                model_output_broadcast_data = {
-                    "logits": logits.contiguous(),
-                } if logits is not None else {}
+            else:
+                # Rare case.
+                assert not self.is_pooling_model
+
+                if not get_pp_group().is_last_rank:
+                    get_pp_group().send_tensor_dict(
+                        hidden_states.tensors, all_gather_group=get_tp_group())
+                    logits = None
+                else:
+                    sample_hidden_states = hidden_states[logits_indices]
+                    logits = self.model.compute_logits(sample_hidden_states,
+                                                       None)
+
+                model_output_broadcast_data = {}
+                if logits is not None:
+                    model_output_broadcast_data["logits"] = logits.contiguous()
+
             model_output_broadcast_data = get_pp_group(
             ).broadcast_tensor_dict(model_output_broadcast_data,
                                     src=len(get_pp_group().ranks) - 1)
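The restructured postprocess now splits on the precomputed flag: the common path returns intermediate tensors on non-last PP ranks (or pools/samples on the last rank), while the rare `external_launcher` path has the last rank broadcast its logits so every PP rank returns the same output. A condensed, hypothetical skeleton of that control flow; the helpers are stand-ins for vLLM's internals, and the non-last-rank `send_tensor_dict` hand-off is elided:

```python
def compute_logits(hidden):
    return [2 * h for h in hidden]  # stand-in for self.model.compute_logits


def broadcast_from_last_rank(data):
    return data  # stand-in for get_pp_group().broadcast_tensor_dict


def postprocess(hidden, *, is_last_rank, broadcast_pp_output,
                is_pooling_model):
    if not broadcast_pp_output:
        # Common case: non-last PP ranks hand intermediates to the next
        # rank; the last rank pools or computes logits locally.
        if not is_last_rank:
            return ("intermediate_tensors", hidden)
        if is_pooling_model:
            return ("pooler_output", hidden)
        return ("logits", compute_logits(hidden))
    # Rare case (external_launcher): only the last rank has logits, so it
    # broadcasts them and all PP ranks leave with identical data.
    assert not is_pooling_model
    logits = compute_logits(hidden) if is_last_rank else None
    data = {"logits": logits} if logits is not None else {}
    data = broadcast_from_last_rank(data)
    return ("logits", data.get("logits"))


print(postprocess([1.0], is_last_rank=True, broadcast_pp_output=False,
                  is_pooling_model=False))  # -> ('logits', [2.0])
```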