[Bugfix] fix pp for llama4 (#16746)

Signed-off-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
Lucia Fang 2025-04-17 22:51:30 -07:00 committed by GitHub
parent aaec845f8e
commit e31045f95c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -672,9 +672,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.config,
None,
prefix=maybe_prefix(prefix, "multi_modal_projector"))
self.language_model = _initialize_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
vllm_config=vllm_config.with_hf_config(config.text_config,
["LlamaForCausalLM"]),
prefix=maybe_prefix(prefix, "language_model"),
model_class=Llama4ForCausalLM,
)
@ -824,7 +824,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
# language_model is a Llama4ForCausalLM instance. We load its
# weights using llama4's load_weights routine.
language_model_weights, other_weights = self.separate_weights(
weights, prefix="language_model.model.")
weights, prefix="language_model.")
loader = AutoWeightsLoader(self)
loaded_language_model_params = loader.load_weights(
language_model_weights)