[Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (#14237)

Signed-off-by: pyc96 <pychen96@gmail.com>
This commit is contained in:
pyc96 2025-03-05 14:22:40 -08:00 committed by GitHub
parent 53ea6ad830
commit 1e3e76b6cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -302,7 +302,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
 outputs.append(output)
 if self.return_hidden_states and is_fallback:
-    output.hidden_states = hidden_states
+    if use_cuda_graph:
+        indices = model_input.sampling_metadata\
+            .selected_token_indices
+        output.hidden_states = hidden_states[:len(indices)]
+    else:
+        output.hidden_states = hidden_states
 if model_input.attn_metadata.num_prefills == 0 \
     and self.indices_of_seq_with_bonus_tokens is not None: