[Bugfix] Refactor composite weight loading logic (#8656)

Isotr0py authored on 2024-09-22 12:33:27 +08:00 (committed via GitHub)
parent d66ac62854
commit 13d88d4137
7 changed files with 70 additions and 61 deletions
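
Summary: each multimodal model's load_weights previously split the incoming weight stream with itertools.tee and ran filter_weights once per component; this commit replaces that per-model boilerplate with a single group_weights_with_prefix call that buckets weights by their top-level name prefix. A minimal sketch of the two patterns (the weight names are illustrative, tensors replaced by strings for brevity; assumes a vLLM tree that includes this commit):

import itertools

from vllm.model_executor.models.utils import (filter_weights,
                                              group_weights_with_prefix)

weights = [("vision_tower.embed.weight", "w0"),
           ("language_model.lm_head.weight", "w1")]

# Old pattern: tee one copy of the stream per component, filter each by prefix.
vit_old, llm_old = itertools.tee(iter(weights), 2)
vit_old = filter_weights(vit_old, "vision_tower")
llm_old = filter_weights(llm_old, "language_model")

# New pattern: group the stream once, then index by component prefix.
group = group_weights_with_prefix(iter(weights))
vit_new, llm_new = group["vision_tower"], group["language_model"]

# Both patterns yield the same prefix-stripped (name, weight) pairs.
assert [n for n, _ in vit_old] == [n for n, _ in vit_new] == ["embed.weight"]
assert [n for n, _ in llm_old] == [n for n, _ in llm_new] == ["lm_head.weight"]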

vllm/model_executor/models/internvl.py

@@ -4,7 +4,6 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import itertools
 import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -33,8 +32,8 @@ from vllm.utils import is_list_of
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -518,21 +517,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_model")
-        self.vision_model.load_weights(vit_weights)
+        self.vision_model.load_weights(weights_group["vision_model"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["mlp1"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/llava.py

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 
 class LlavaImagePixelInputs(TypedDict):
@@ -393,21 +392,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/llava_next.py

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -30,8 +29,8 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
@@ -637,25 +636,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load newline
-        newline_weights = filter_weights(newline_weights, "image_newline")
-        for name, loaded_weight in newline_weights:
+        for name, loaded_weight in weights_group["image_newline"]:
             assert name == ""
             param = self.image_newline
             weight_loader = getattr(param, "weight_loader",
@@ -663,5 +658,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
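
A note on the image_newline branch above: image_newline is a bare parameter rather than a submodule, so its checkpoint key is exactly the prefix itself, and stripping the prefix leaves an empty name, which the assert name == "" checks. A tiny sketch of the stripping rule (_strip_prefix here is a hypothetical stand-in for what filter_weights does to each key):

def _strip_prefix(name: str, prefix: str) -> str:
    # Drop the leading dotted component, which must equal the prefix.
    head, _, rest = name.partition(".")
    assert head == prefix
    return rest

assert _strip_prefix("image_newline", "image_newline") == ""
assert _strip_prefix("vision_tower.embed.weight", "vision_tower") == "embed.weight"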

vllm/model_executor/models/llava_next_video.py

@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -30,7 +29,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (group_weights_with_prefix, init_vllm_registered_model,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
@@ -449,23 +448,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
         return self.language_model.sample(logits, sampling_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/paligemma.py

@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import filter_weights, merge_multimodal_embeddings
+from .utils import group_weights_with_prefix, merge_multimodal_embeddings
 
 logger = init_logger(__name__)
@@ -286,21 +285,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)
 
         # load vision tower
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])
 
         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/ultravox.py

@@ -1,7 +1,6 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""
-import itertools
 import math
 from array import array
 from functools import lru_cache
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
+from vllm.model_executor.models.utils import (flatten_bn,
+                                              group_weights_with_prefix,
                                               init_vllm_registered_model,
                                               merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -467,11 +467,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        projector_weights, llm_weights = itertools.tee(weights, 2)
+        weights_group = group_weights_with_prefix(weights)
 
         # load projector weights
-        projector_weights = filter_weights(projector_weights,
-                                           "multi_modal_projector")
+        projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
             self.multi_modal_projector.named_parameters())
         for name, loaded_weight in projector_weights:
@@ -481,5 +480,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
             weight_loader(param, loaded_weight)
 
         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])

vllm/model_executor/models/utils.py

@@ -1,3 +1,5 @@
+import itertools
+from collections import UserDict
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)
@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+class WeightsGroup(UserDict):
+    """
+    Wraps a grouped weights dictionary to give a more informative error
+    message when attempting to access a weight component that does not exist.
+    """
+
+    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
+        try:
+            return super().__getitem__(key)
+        except KeyError as exc:
+            msg = (f"There are no weights named with the prefix: {key}. "
+                   f"Available prefixes: {set(self.keys())}")
+            raise KeyError(msg) from exc
+
+
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
     Helper function to load weights for inner vLLM models.
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
         yield name, loaded_weight
 
 
+def group_weights_with_prefix(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+    """
+    Helper function to group weights by their top-level name prefix.
+    """
+    init_weights, repeated_weights = itertools.tee(weights, 2)
+    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
+    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+
+    return WeightsGroup({
+        prefix: filter_weights(component, prefix)
+        for component, prefix in zip(repeated_weights, weights_prefix)
+    })
+
+
 def init_vllm_registered_model(
     hf_config: PretrainedConfig,
     cache_config: Optional[CacheConfig],
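
To make the new helpers' behavior concrete, here is a self-contained sketch that mirrors the definitions above outside of vLLM (tensors replaced by strings; the body of filter_weights is paraphrased, since this diff elides it):

import itertools
from collections import UserDict


class WeightsGroup(UserDict):
    # Raises a KeyError that lists the available prefixes.
    def __getitem__(self, key):
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            raise KeyError(f"There are no weights named with the prefix: {key}. "
                           f"Available prefixes: {set(self.keys())}") from exc


def filter_weights(weights, prefix):
    # Keep entries under `prefix`, stripping the prefix from each name.
    for name, weight in weights:
        head, _, rest = name.partition(".")
        if head == prefix:
            yield rest, weight


def group_weights_with_prefix(weights):
    # First pass collects the set of top-level prefixes; the tee'd copies
    # then feed one lazy filter_weights generator per prefix.
    init_weights, repeated_weights = itertools.tee(weights, 2)
    prefixes = {name.split(".")[0] for name, _ in init_weights}
    copies = itertools.tee(repeated_weights, len(prefixes))
    return WeightsGroup({
        prefix: filter_weights(copy, prefix)
        for copy, prefix in zip(copies, prefixes)
    })


weights = [("vision_tower.patch_embed.weight", "w0"),
           ("language_model.embed_tokens.weight", "w1")]
group = group_weights_with_prefix(iter(weights))
print(dict(group["vision_tower"]))    # {'patch_embed.weight': 'w0'}
print(dict(group["language_model"]))  # {'embed_tokens.weight': 'w1'}

try:
    group["mlp1"]
except KeyError as err:
    print(err)  # names the missing prefix and lists the available ones

One design note: the grouped values are lazy generators over tee'd copies of the original stream, so each component still sees weights in checkpoint order; keep in mind that itertools.tee buffers items until every copy has consumed them, so components read strictly one after another still cause the stream to be buffered internally.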