[Llama4] Enable attention temperature tuning by default for long context (>32k) (#16439)

Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
Author: Yong Hoon Shin
Date: 2025-04-10 21:26:07 -07:00 (committed by GitHub)
commit 99ef59cf7f
parent d544d141ec

@@ -467,11 +467,15 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        # Update temperature tuning config from generation config
+        # update temperature tuning config from generation config
         gen_config = vllm_config.model_config.try_get_generation_config()
         gen_config.update(vllm_config.model_config.override_generation_config)
+        # enable temperature tuning by default when max_model_len > 32K
+        default_attn_temperature_tuning = \
+            vllm_config.model_config.max_model_len > 32768
         vllm_config.model_config.hf_config.attn_temperature_tuning \
-            = gen_config.get("attn_temperature_tuning", False)
+            = gen_config.get(
+                "attn_temperature_tuning", default_attn_temperature_tuning)
 
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
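
For reference, the resolution order this change introduces can be sketched as a small standalone helper: an explicit attn_temperature_tuning value in the (possibly overridden) generation config still wins, and only when no value is supplied does the default flip to enabled for models whose max_model_len exceeds 32K (32768 tokens). The helper name and the plain-dict arguments below are illustrative only, not part of the vLLM API.

# Minimal sketch of the defaulting logic added by this commit.
# resolve_attn_temperature_tuning is a hypothetical helper, not a vLLM API;
# it mirrors how Llama4ForCausalLM.__init__ now picks the effective value.
def resolve_attn_temperature_tuning(
    generation_config: dict,
    override_generation_config: dict,
    max_model_len: int,
) -> bool:
    # Values from the override config take precedence over the model's own
    # generation config (same as the gen_config.update(...) call above).
    gen_config = dict(generation_config)
    gen_config.update(override_generation_config)

    # New behavior: default to enabled for long-context models (> 32K tokens).
    default_attn_temperature_tuning = max_model_len > 32768
    return gen_config.get("attn_temperature_tuning",
                          default_attn_temperature_tuning)

# Resulting behavior:
assert resolve_attn_temperature_tuning({}, {}, 8192) is False     # short context: off, as before
assert resolve_attn_temperature_tuning({}, {}, 131072) is True    # long context: now on by default
assert resolve_attn_temperature_tuning(
    {}, {"attn_temperature_tuning": False}, 131072) is False      # explicit setting still wins

Users who want the previous behavior for long-context runs can therefore still set attn_temperature_tuning to false via the generation config override; the new default only applies when the key is absent.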