Optimized topk for topk=1 (Llama-4) (#16512)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
parent
802329dee9
commit
bd6028d6b0
@ -37,7 +37,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .llama import LlamaForCausalLM, LlamaMLP, LlamaModel
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index,
|
||||
from .utils import (AutoWeightsLoader, extract_layer_index, fast_topk,
|
||||
is_pp_missing_parameter)
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ class Llama4MoE(nn.Module):
|
||||
topk: int,
|
||||
renormalize: bool,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
router_scores, router_indices = torch.topk(gating_output, topk, dim=-1)
|
||||
router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
|
||||
router_scores = torch.sigmoid(router_scores.float()).to(
|
||||
hidden_states.dtype)
|
||||
return (router_scores, router_indices.to(torch.int32))
|
||||
|
@ -703,3 +703,12 @@ def cast_overflow_tensors(
|
||||
clamp_value = torch.finfo(tensors.dtype).max - offset
|
||||
tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
|
||||
return tensors
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
|
||||
if topk == 1:
|
||||
# Use max along the specified dimension to get both value and index
|
||||
return torch.max(values, dim=dim, keepdim=True)
|
||||
else:
|
||||
# Use topk for efficiency with larger k values
|
||||
return torch.topk(values, topk, dim=dim)
|
||||
|
Loading…
x
Reference in New Issue
Block a user