[Llama4] Enable attention temperature tuning by default for long context (>32k) (#16439)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
Co-authored-by: Ye (Charlotte) Qi <yeq@meta.com>
parent d544d141ec
commit 99ef59cf7f
@@ -467,11 +467,15 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        # Update temperature tuning config from generation config
+        # update temperature tuning config from generation config
         gen_config = vllm_config.model_config.try_get_generation_config()
         gen_config.update(vllm_config.model_config.override_generation_config)
+        # enable temperature tuning by default when max_model_len > 32K
+        default_attn_temperature_tuning = \
+            vllm_config.model_config.max_model_len > 32768
         vllm_config.model_config.hf_config.attn_temperature_tuning \
-            = gen_config.get("attn_temperature_tuning", False)
+            = gen_config.get(
+                "attn_temperature_tuning", default_attn_temperature_tuning)
 
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
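The net effect: for Llama 4, `attn_temperature_tuning` now defaults to on whenever the configured context window exceeds 32k tokens, while a value set explicitly in the generation config (or merged in via `override_generation_config`) still wins. Below is a minimal sketch of that precedence; `resolve_attn_temperature_tuning` is a hypothetical helper for illustration, not vLLM API, though `gen_config`, `max_model_len`, and the 32768 threshold mirror the diff.

```python
def resolve_attn_temperature_tuning(gen_config: dict,
                                    max_model_len: int) -> bool:
    """Illustrative helper mirroring the diff's precedence rules.

    1. An explicit "attn_temperature_tuning" in the merged generation
       config always wins.
    2. Otherwise, default to True for long-context deployments
       (max_model_len > 32768) and False for shorter ones.
    """
    default = max_model_len > 32768
    return gen_config.get("attn_temperature_tuning", default)


# Hypothetical examples of the resulting behavior:
assert resolve_attn_temperature_tuning({}, max_model_len=131072) is True
assert resolve_attn_temperature_tuning({}, max_model_len=8192) is False
# An explicit setting overrides the length-based default:
assert resolve_attn_temperature_tuning(
    {"attn_temperature_tuning": False}, max_model_len=131072) is False
```

Since `override_generation_config` feeds into the same `gen_config.get(...)` lookup, passing something like `--override-generation-config '{"attn_temperature_tuning": false}'` on the vLLM command line should be one way to force the feature off for a long-context run, per the precedence above.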