[Bugfix] Fix DeepSeek MTP crash when using TP1DraftModelRunner with CUDA graph due to shape mismatch (#14237)
Signed-off-by: pyc96 <pychen96@gmail.com>
commit 1e3e76b6cc
parent 53ea6ad830
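The crash this commit fixes comes from CUDA graphs running the draft model at a fixed, padded batch size: the hidden_states returned from a graph-executed step then has more rows than there are live sequences, and DeepSeek MTP code that expects one row per sequence fails with a shape mismatch. Below is a minimal sketch of the mismatch and of the trimming idea the patch applies; the tensor sizes and variable names are illustrative, not taken from the vLLM sources.

    import torch

    # Illustrative sizes: CUDA graphs are captured at fixed batch sizes, so a
    # batch of 3 decode sequences may execute through a graph padded to 8.
    real_batch, padded_batch, hidden_size = 3, 8, 16

    # A graph-executed forward pass returns hidden states for the padded
    # batch, including rows that belong to padding slots.
    hidden_states = torch.randn(padded_batch, hidden_size)

    # For a decode-only batch, sampling selects one token per live sequence,
    # so the number of selected indices equals the real batch size.
    selected_token_indices = torch.arange(real_batch)

    # Without trimming, downstream code sees 8 rows where it expects 3.
    assert hidden_states.shape[0] != real_batch

    # The fix trims the padded output back to the real batch.
    trimmed = hidden_states[:len(selected_token_indices)]
    assert trimmed.shape[0] == real_batch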
@@ -133,7 +133,7 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
     def supports_gpu_multi_step(self, execute_model_req: ExecuteModelRequest):
         """Determines if draft_model_runner GPU multi-step can be used.
         Currently required conditions are:
             1. Only decodes
             2. Only flash-attn
             3. No LORA
             4. No prompt_adapter_config
@@ -171,12 +171,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
         num_steps: int = 1,
         **kwargs,
     ) -> Optional[List[SamplerOutput]]:
         """Executes num_steps forward passes with advancement of input tensors
         on the GPU. Look at supports_gpu_multi_step(..) for pre-conditions.
 
         Optimizations used:
             1. Input tensors are updated on the GPU directly
             2. Skips GPU=>CPU serialization of sampler outputs (we don't need
                them since we do batch expansion later that uses GPU outputs)
             3. Reuses sampling tensors (since we run only decodes and they have
                a repeating sampling logic)
@@ -302,7 +302,12 @@ class TP1DraftModelRunner(ModelRunnerWrapperBase):
             outputs.append(output)
 
             if self.return_hidden_states and is_fallback:
-                output.hidden_states = hidden_states
+                if use_cuda_graph:
+                    indices = model_input.sampling_metadata\
+                        .selected_token_indices
+                    output.hidden_states = hidden_states[:len(indices)]
+                else:
+                    output.hidden_states = hidden_states
 
         if model_input.attn_metadata.num_prefills == 0 \
             and self.indices_of_seq_with_bonus_tokens is not None:
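Taken on its own, the patched branch behaves like the helper below; this is a paraphrase of the change for readability, with trim_hidden_states as a hypothetical name, since the real logic is inline in execute_model.

    import torch

    def trim_hidden_states(hidden_states: torch.Tensor,
                           selected_token_indices: torch.Tensor,
                           use_cuda_graph: bool) -> torch.Tensor:
        # Under a CUDA graph the forward pass runs at the padded capture
        # size, so only the first len(selected_token_indices) rows hold
        # real decode tokens.
        if use_cuda_graph:
            return hidden_states[:len(selected_token_indices)]
        # Eager execution returns an exact-sized tensor; nothing to trim.
        return hidden_states

Note that the patch slices with [:len(indices)] rather than gathering with the indices themselves, which appears safe here because the runner only supports decode-only batches, where each sequence contributes exactly one token and the selected tokens occupy the leading rows.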