[Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (#14237)

Signed-off-by: pyc96 <pychen96@gmail.com>
Author: pyc96, 2025-03-05 14:22:40 -08:00 (committed by GitHub)
parent 53ea6ad830
commit 1e3e76b6cc

@@ -133,7 +133,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
         """Determines if draft_model_runner GPU multi-step can be used.
         Currently required conditions are:
             1. Only decodes
             2. Only flash-attn
             3. No LORA
             4. No prompt_adapter_config
@@ -171,12 +171,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
         num_steps: int = 1,
         **kwargs,
     ) -> Optional[List[SamplerOutput]]:
         """Executes num_steps forward passes with advancement of input tensors
         on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.

         Optimizations used:
             1. Input tensors are updated on the GPU directly
             2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                them since we do batch expansion later that uses GPU outputs)
             3. Reuses sampling tensors (since we run only decodes and they have
                a repeating sampling logic)
@@ -302,7 +302,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             outputs.append(output)

             if self.return_hidden_states and is_fallback:
-                output.hidden_states = hidden_states
+                if use_cuda_graph:
+                    indices = model_input.sampling_metadata\
+                        .selected_token_indices
+                    output.hidden_states = hidden_states[:len(indices)]
+                else:
+                    output.hidden_states = hidden_states

         if model_input.attn_metadata.num_prefills == 0 \
                 and self.indices_of_seq_with_bonus_tokens is not None:
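Context for the fix: when CUDA graphs are used, the decode batch is padded up to the captured graph's batch size, so the returned hidden states carry rows for the padded slots as well, while downstream consumers expect exactly one row per sampled token. Slicing to len(selected_token_indices) restores that agreement. Below is a minimal, self-contained sketch of the shape mismatch and the slice; the sizes and variable names are illustrative only, not vLLM's actual runner internals.

    import torch

    # Illustrative sizes (assumed, not taken from vLLM): CUDA graph capture
    # pads the decode batch from 3 real sequences up to a graph size of 8.
    real_batch_size = 3
    padded_batch_size = 8
    hidden_size = 16

    # The padded forward pass returns hidden states for every padded row.
    hidden_states = torch.randn(padded_batch_size, hidden_size)

    # sampling_metadata.selected_token_indices covers only the real tokens.
    selected_token_indices = torch.arange(real_batch_size)

    # Before this commit the full padded tensor was stored, so its first
    # dimension (8) disagreed with the 3 sampled tokens consumed later.
    # The fix slices the tensor down to the real rows:
    output_hidden_states = hidden_states[:len(selected_token_indices)]
    assert output_hidden_states.shape == (real_batch_size, hidden_size)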