[Bugfix] Fix DeepSeek MTP crash when using TP1ModelRunner with CUDA graph due to shape mismatch (#14237)
Signed-off-by: pyc96 <pychen96@gmail.com>
This commit is contained in:
parent
53ea6ad830
commit
1e3e76b6cc
@ -302,6 +302,11 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
|
||||
outputs.append(output)
|
||||
|
||||
if self.return_hidden_states and is_fallback:
|
||||
if use_cuda_graph:
|
||||
indices = model_input.sampling_metadata\
|
||||
.selected_token_indices
|
||||
output.hidden_states = hidden_states[:len(indices)]
|
||||
else:
|
||||
output.hidden_states = hidden_states
|
||||
|
||||
if model_input.attn_metadata.num_prefills == 0 \
|
||||
|
Loading…
x
Reference in New Issue
Block a user