[Model][Bugfix]: correct Aria model output (#12309)

Signed-off-by: xffxff <1247714429@qq.com>

commit 528dbcac7d
parent cd7b6f0857
@@ -28,9 +28,10 @@ def run_aria(question: str, modality: str):
     llm = LLM(model=model_name,
               max_model_len=4096,
               max_num_seqs=2,
+              dtype="bfloat16",
               disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache)
 
-    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>\n{question}"
+    prompt = (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
               "<|im_end|>\n<|im_start|>assistant\n")
 
     stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519]
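The example-script hunk above makes two corrections: it pins dtype="bfloat16" for the Aria checkpoint and removes a stray "\n" after <fim_suffix>, presumably so the question follows the image placeholder directly in the prompt format the model expects. A minimal sketch of the corrected prompt layout (build_aria_prompt is our illustrative helper, not part of the diff):

    def build_aria_prompt(question: str) -> str:
        # No newline between the image placeholder and the user text
        return (f"<|im_start|>user\n<fim_prefix><|img|><fim_suffix>{question}"
                "<|im_end|>\n<|im_start|>assistant\n")

    print(build_aria_prompt("What is in this image?"))

The remaining hunks change the Aria model implementation itself.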
@@ -30,6 +30,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
 
 # yapf: disable
+from .idefics2_vision_model import Idefics2VisionConfig
 from .idefics2_vision_model import (
     Idefics2VisionTransformer as Idefics3VisionTransformer)
 # yapf: enable
@@ -50,6 +51,53 @@ class AriaImagePixelInputs(TypedDict):
     """
 
 
+class AriaVisionTransformer(Idefics3VisionTransformer):
+
+    def __init__(
+        self,
+        config: Idefics2VisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, quant_config, prefix)
+        # Unlike Idefics3VisionTransformer which uses LayerNorm after the
+        # final layer, Aria omits this normalization, so we replace it with an
+        # Identity layer
+        self.post_layernorm = nn.Identity()
+
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            # NOTE: post_layernorm is not used in Aria
+            if "post_layernorm" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
+
+
 class AriaProjectorMLP(nn.Module):
 
     def __init__(
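The new AriaVisionTransformer is the core of the fix: Aria's checkpoint has no final LayerNorm after the vision encoder, while the reused Idefics3VisionTransformer applies one, so the subclass swaps it for nn.Identity() and its load_weights skips any post_layernorm weights from the checkpoint. A standalone toy sketch of both ideas (the Toy* modules are ours for illustration, not the vLLM classes):

    import torch
    import torch.nn as nn

    class ToyEncoder(nn.Module):
        # Stand-in for Idefics3VisionTransformer: norm after the last layer
        def __init__(self):
            super().__init__()
            self.layers = nn.Linear(8, 8)
            self.post_layernorm = nn.LayerNorm(8)

        def forward(self, x):
            return self.post_layernorm(self.layers(x))

    class ToyEncoderNoFinalNorm(ToyEncoder):
        # Stand-in for AriaVisionTransformer: the final norm becomes a no-op
        def __init__(self):
            super().__init__()
            self.post_layernorm = nn.Identity()

    model = ToyEncoderNoFinalNorm()
    state = ToyEncoder().state_dict()
    # Mirror the `continue` in load_weights: drop the unused norm weights
    state = {k: v for k, v in state.items() if "post_layernorm" not in k}
    model.load_state_dict(state)
    x = torch.randn(2, 8)
    assert torch.equal(model.post_layernorm(x), x)  # Identity passes through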
@@ -228,8 +276,10 @@ class AriaTextMoELayer(nn.Module):
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)
 
+        hidden_states_copy = hidden_states.clone()
+        # NOTE: hidden_states will be modified inplace by `FusedMoE`
         sparse_expert_output = self.experts(hidden_states, router_output)
-        shared_expert_output = self.shared_experts(hidden_states)
+        shared_expert_output = self.shared_experts(hidden_states_copy)
 
         return sparse_expert_output + shared_expert_output
 
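This hunk fixes an aliasing bug: FusedMoE modifies hidden_states in place, so feeding the same tensor to the shared experts afterwards would hand them already-overwritten activations. Cloning first preserves the original input for the shared branch. A toy reproduction of the hazard (our own sketch; mul_ stands in for the fused kernel's in-place write):

    import torch

    def fused_op(x: torch.Tensor) -> torch.Tensor:
        return x.mul_(2)  # mutates its input buffer in place, like FusedMoE

    def shared_branch(x: torch.Tensor) -> torch.Tensor:
        return x + 1

    h = torch.ones(4)
    _ = fused_op(h)                     # h now holds the fused output (2.0)
    buggy = shared_branch(h)            # wrong: sees 2.0 + 1 = 3.0

    h = torch.ones(4)
    h_copy = h.clone()                  # the fix: snapshot before the fused op
    _ = fused_op(h)
    correct = shared_branch(h_copy)     # right: sees 1.0 + 1 = 2.0
    assert buggy.eq(3.0).all() and correct.eq(2.0).all()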
@@ -445,7 +495,7 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         quant_config = vllm_config.quant_config
 
         self.config = config
-        self.vision_tower = Idefics3VisionTransformer(
+        self.vision_tower = AriaVisionTransformer(
             config.vision_config,
             quant_config,
             prefix=f"{prefix}.vision_tower",