diff --git a/vllm/model_executor/models/llama4.py b/vllm/model_executor/models/llama4.py index 3dbf352a..8785e9dc 100644 --- a/vllm/model_executor/models/llama4.py +++ b/vllm/model_executor/models/llama4.py @@ -467,11 +467,15 @@ class Llama4ForCausalLM(LlamaForCausalLM): } def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Update temperature tuning config from generation config + # update temperature tuning config from generation config gen_config = vllm_config.model_config.try_get_generation_config() gen_config.update(vllm_config.model_config.override_generation_config) + # enable temperature tuning by default when max_model_len > 32K + default_attn_temperature_tuning = \ + vllm_config.model_config.max_model_len > 32768 vllm_config.model_config.hf_config.attn_temperature_tuning \ - = gen_config.get("attn_temperature_tuning", False) + = gen_config.get( + "attn_temperature_tuning", default_attn_temperature_tuning) super().__init__(vllm_config=vllm_config, prefix=prefix,