[Pixtral] Improve loading (#11040)
This commit is contained in:
parent
980ad394a8
commit
bc192a2b09
@ -1,6 +1,5 @@
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from itertools import tee
|
||||
from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
|
||||
|
||||
import numpy
|
||||
@ -359,38 +358,33 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
def is_vision_lang_adapter_weights(weight: Tuple[str, torch.Tensor]):
|
||||
return weight[0].startswith("vision_language_adapter")
|
||||
|
||||
def is_vision_weights(weight: Tuple[str, torch.Tensor]):
|
||||
return is_vision_encoder_weights(
|
||||
weight) or is_vision_lang_adapter_weights(weight)
|
||||
|
||||
llm_weights, vision_encoder_weights, vision_lang_adapter_weights = tee(
|
||||
weights, 3)
|
||||
|
||||
# llm
|
||||
llm_weights = filter(lambda x: not is_vision_weights(x), llm_weights)
|
||||
self.language_model.load_weights(llm_weights)
|
||||
|
||||
# vision encoder
|
||||
vision_encoder_weights = filter(is_vision_encoder_weights,
|
||||
vision_encoder_weights)
|
||||
# Get references to parameters for direct loading
|
||||
vision_encoder_dict = dict(self.vision_encoder.named_parameters())
|
||||
for name, loaded_weight in vision_encoder_weights:
|
||||
# cut 'vision_encoder.'
|
||||
name = '.'.join(name.split(".")[1:])
|
||||
param = vision_encoder_dict[name]
|
||||
|
||||
default_weight_loader(param, loaded_weight)
|
||||
|
||||
# adapter
|
||||
vision_lang_adapter_weights = filter(is_vision_lang_adapter_weights,
|
||||
vision_lang_adapter_weights)
|
||||
vision_lang_adpter_dict = dict(
|
||||
vision_lang_adapter_dict = dict(
|
||||
self.vision_language_adapter.named_parameters())
|
||||
for name, loaded_weight in vision_lang_adapter_weights:
|
||||
# cut 'vision_language_adapter.'
|
||||
name = '.'.join(name.split(".")[1:])
|
||||
param = vision_lang_adpter_dict[name]
|
||||
default_weight_loader(param, loaded_weight)
|
||||
|
||||
def llm_weights_generator():
|
||||
# Single pass over weights
|
||||
for name, w in weights:
|
||||
if is_vision_encoder_weights((name, w)):
|
||||
# Load vision encoder weights directly
|
||||
trimmed_name = '.'.join(name.split(".")[1:])
|
||||
param = vision_encoder_dict[trimmed_name]
|
||||
with torch.no_grad():
|
||||
default_weight_loader(param, w)
|
||||
elif is_vision_lang_adapter_weights((name, w)):
|
||||
# Load vision-language adapter weights directly
|
||||
trimmed_name = '.'.join(name.split(".")[1:])
|
||||
param = vision_lang_adapter_dict[trimmed_name]
|
||||
with torch.no_grad():
|
||||
default_weight_loader(param, w)
|
||||
else:
|
||||
# LLM weights: yield them to be loaded
|
||||
# by language_model.load_weights
|
||||
yield (name, w)
|
||||
|
||||
# Now we call the language model load with the generator
|
||||
self.language_model.load_weights(llm_weights_generator())
|
||||
|
||||
|
||||
# Vision encoder
|
||||
|
Loading…
x
Reference in New Issue
Block a user