[Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (#14237)

Signed-off-by: pyc96 <pychen96@gmail.com>
This commit is contained in:
pyc96 2025-03-05 14:22:40 -08:00 committed by GitHub
parent 53ea6ad830
commit 1e3e76b6cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -302,6 +302,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
outputs.append(output)
if self.return_hidden_states and is_fallback:
if use_cuda_graph:
indices = model_input.sampling_metadata\
.selected_token_indices
output.hidden_states = hidden_states[:len(indices)]
else:
output.hidden_states = hidden_states
if model_input.attn_metadata.num_prefills == 0 \