[Pixtral] Improve loading (#11040)

This commit is contained in:
Patrick von Platen 2024-12-10 07:09:32 +01:00 committed by GitHub
parent 980ad394a8
commit bc192a2b09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1,6 +1,5 @@
from dataclasses import dataclass, fields
from functools import cached_property
from itertools import tee
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
import numpy
@@ -359,38 +358,33 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
return weight[0].startswith("vision_language_adapter")
def is_vision_weights(weight: Tuple[str, torch.Tensor]):
# Get references to parameters for direct loading
return is_vision_encoder_weights(
weight) or is_vision_lang_adapter_weights(weight)
llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
weights, 3)
# llm
llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
self.language_model.load_weights(llm_weights)
# vision encoder
vision_encoder_weights = filter(is_vision_encoder_weights,
vision_encoder_weights)
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
for name, loaded_weight in vision_encoder_weights:
vision_lang_adapter_dict = dict(
# cut 'vision_encoder.'
name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[name]
default_weight_loader(param, loaded_weight)
# adapter
vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
vision_lang_adapter_weights)
vision_lang_adpter_dict = dict(
self.vision_language_adapter.named_parameters())
for name, loaded_weight in vision_lang_adapter_weights:
# cut 'vision_language_adapter.'
def llm_weights_generator():
name = '.'.join(name.split(".")[1:])
# Single pass over weights
param = vision_lang_adpter_dict[name]
for name, w in weights:
default_weight_loader(param, loaded_weight)
if is_vision_encoder_weights((name, w)):
# Load vision encoder weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_encoder_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
elif is_vision_lang_adapter_weights((name, w)):
# Load vision-language adapter weights directly
trimmed_name = '.'.join(name.split(".")[1:])
param = vision_lang_adapter_dict[trimmed_name]
with torch.no_grad():
default_weight_loader(param, w)
else:
# LLM weights: yield them to be loaded
# by language_model.load_weights
yield (name, w)
# Now we call the language model load with the generator
self.language_model.load_weights(llm_weights_generator())
# Vision encoder