[Bugfix] Refactor composite weight loading logic (#8656)
parent d66ac62854
commit 13d88d4137
@@ -4,7 +4,6 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
-import itertools
 import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -33,8 +32,8 @@ from vllm.utils import is_list_of
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
 from .interfaces import SupportsMultiModal
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)

 IMG_START = '<img>'
 IMG_END = '</img>'
@@ -518,21 +517,18 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_model")
-        self.vision_model.load_weights(vit_weights)
+        self.vision_model.load_weights(weights_group["vision_model"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "mlp1")
         mlp_params_dict = dict(self.mlp1.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["mlp1"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
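Every load_weights method touched in this commit follows the same pattern: instead of tee-ing the raw checkpoint iterator once per component and filtering each copy by prefix at the call site, the stream is grouped once by its top-level prefix and each submodule consumes an iterator whose prefix is already stripped. A minimal sketch of that pattern with made-up checkpoint names (group_weights_with_prefix is the helper added to vllm.model_executor.models.utils at the end of this diff):

import torch
from vllm.model_executor.models.utils import group_weights_with_prefix

# hypothetical (name, tensor) pairs as they would arrive from a checkpoint
weights = [
    ("vision_model.embeddings.patch_embedding.weight", torch.empty(1)),
    ("mlp1.1.weight", torch.empty(1)),
    ("language_model.model.embed_tokens.weight", torch.empty(1)),
]

weights_group = group_weights_with_prefix(weights)

# each component receives a lazy iterator with the leading prefix stripped,
# e.g. ("embeddings.patch_embedding.weight", <tensor>) for the vision encoder,
# which can be handed straight to that submodule's load_weights(...)
for name, tensor in weights_group["vision_model"]:
    print(name)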
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)

@@ -26,8 +25,8 @@ from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                      input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)


 class LlavaImagePixelInputs(TypedDict):
@@ -393,21 +392,18 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal):

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)

@@ -30,8 +29,8 @@ from .llava import LlavaMultiModalProjector
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_siglip_image_feature_size,
                      get_siglip_patch_grid_length, input_processor_for_siglip)
-from .utils import (filter_weights, flatten_bn, init_vllm_registered_model,
-                    merge_multimodal_embeddings)
+from .utils import (flatten_bn, group_weights_with_prefix,
+                    init_vllm_registered_model, merge_multimodal_embeddings)

 logger = init_logger(__name__)

@@ -637,25 +636,21 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load newline
-        newline_weights = filter_weights(newline_weights, "image_newline")
-        for name, loaded_weight in newline_weights:
+        for name, loaded_weight in weights_group["image_newline"]:
             assert name == ""
             param = self.image_newline
             weight_loader = getattr(param, "weight_loader",
@@ -663,5 +658,4 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal):
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
@@ -30,7 +29,7 @@ from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip)
-from .utils import (filter_weights, init_vllm_registered_model,
+from .utils import (group_weights_with_prefix, init_vllm_registered_model,
                     merge_multimodal_embeddings)

 logger = init_logger(__name__)
@@ -449,23 +448,19 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal):
         return self.language_model.sample(logits, sampling_metadata)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # prepare weight iterators
-        vit_weights, mlp_weights, newline_weights, llm_weights = itertools.tee(
-            weights, 4)
+        # prepare weight iterators for components
+        weights_group = group_weights_with_prefix(weights)

         # load vision encoder
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
@@ -1,4 +1,3 @@
-import itertools
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)

@@ -23,7 +22,7 @@ from vllm.sequence import IntermediateTensors
 from .interfaces import SupportsMultiModal
 from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                      dummy_seq_data_for_siglip, get_max_siglip_image_tokens)
-from .utils import filter_weights, merge_multimodal_embeddings
+from .utils import group_weights_with_prefix, merge_multimodal_embeddings

 logger = init_logger(__name__)

@@ -286,21 +285,18 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal):

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3)
+        weights_group = group_weights_with_prefix(weights)

         # load vision tower
-        vit_weights = filter_weights(vit_weights, "vision_tower")
-        self.vision_tower.load_weights(vit_weights)
+        self.vision_tower.load_weights(weights_group["vision_tower"])

         # load mlp projector
-        mlp_weights = filter_weights(mlp_weights, "multi_modal_projector")
         mlp_params_dict = dict(self.multi_modal_projector.named_parameters())
-        for name, loaded_weight in mlp_weights:
+        for name, loaded_weight in weights_group["multi_modal_projector"]:
             param = mlp_params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
@@ -1,7 +1,6 @@
 # Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
 """PyTorch Ultravox model."""

-import itertools
 import math
 from array import array
 from functools import lru_cache
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
+from vllm.model_executor.models.utils import (flatten_bn,
+                                              group_weights_with_prefix,
                                               init_vllm_registered_model,
                                               merge_multimodal_embeddings)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -467,11 +467,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal):

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # prepare weight iterators for components
-        projector_weights, llm_weights = itertools.tee(weights, 2)
+        weights_group = group_weights_with_prefix(weights)

         # load projector weights
-        projector_weights = filter_weights(projector_weights,
-                                           "multi_modal_projector")
+        projector_weights = weights_group["multi_modal_projector"]
         projector_params_dict = dict(
             self.multi_modal_projector.named_parameters())
         for name, loaded_weight in projector_weights:
@@ -481,5 +480,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
             weight_loader(param, loaded_weight)

         # load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])
@@ -1,3 +1,5 @@
+import itertools
+from collections import UserDict
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)

@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available


-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+class WeightsGroup(UserDict):
+    """
+    Wraps grouped weights dictionary for a more informative error message
+    when attempting to access a weight component that does not exist.
+    """
+
+    def __getitem__(self, key: str) -> int:
+        try:
+            return super().__getitem__(key)
+        except KeyError as exc:
+            msg = (f"There is no weights named with the prefix: {key}. "
+                   f"Available prefix: {set(self.keys())}")
+            raise KeyError(msg) from exc
+
+
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
     Helper function to load weights for inner vLLM models.
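The WeightsGroup wrapper above only changes what happens on a failed lookup; a short sketch of the intended behaviour (the dictionary contents here are made up for illustration):

from vllm.model_executor.models.utils import WeightsGroup

group = WeightsGroup({"vision_tower": [], "language_model": []})

try:
    group["audio_tower"]
except KeyError as exc:
    # instead of a bare KeyError('audio_tower'), the message names the missing
    # prefix and lists the prefixes that actually exist in the checkpoint
    print(exc)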
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
         yield name, loaded_weight


+def group_weights_with_prefix(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+    """
+    Helper function to group weights with prefix
+    """
+    init_weights, repeated_weights = itertools.tee(weights, 2)
+    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
+    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+
+    return WeightsGroup({
+        prefix: filter_weights(component, prefix)
+        for component, prefix in zip(repeated_weights, weights_prefix)
+    })
+
+
 def init_vllm_registered_model(
     hf_config: PretrainedConfig,
     cache_config: Optional[CacheConfig],
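One usage note on the new helper, sketched below with a made-up name: each entry of the returned WeightsGroup is a one-shot generator (filter_weights applied to a tee'd copy of the stream), so a given component's weights can only be iterated once; the load_weights methods above consume each prefix exactly once for this reason.

import torch
from vllm.model_executor.models.utils import group_weights_with_prefix

# hypothetical single-entry checkpoint stream
weights = [("language_model.lm_head.weight", torch.empty(1))]
weights_group = group_weights_with_prefix(weights)

llm_weights = weights_group["language_model"]
print(list(llm_weights))  # [('lm_head.weight', tensor(...))]
print(list(llm_weights))  # [] -- the generator is already exhausted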