[Model]: Add transformers backend support (#11330)

# Adds support for `transformers` as a backend

Following https://github.com/huggingface/transformers/pull/35235, a
number of models should already be supported, and we are ramping up
support for more.

Thanks @Isotr0py for the TP support, and @hmellor for his help as well!
This includes:
- `trust_remote_code=True` support: any model on the Hub that implements attention the correct way can be natively supported!
- tensor parallel support
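
For illustration, here is a minimal sketch of opting into the new backend explicitly via the `model_impl` engine argument added in this PR (the model name is just an example; `"auto"`, the default, only falls back when no vLLM implementation exists):

```python
from vllm import LLM, SamplingParams

# Force the Transformers backend; with "auto" (the default) vLLM only falls
# back to it when it has no native implementation for the architecture.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```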

---------

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <41363108+Isotr0py@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Arthur 2025-02-03 14:30:38 +01:00 committed by GitHub
parent 1298a400e8
commit a1a2aaadb9
11 changed files with 528 additions and 9 deletions


@ -349,6 +349,7 @@ steps:
- vllm/
- tests/models
commands:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py
@ -485,6 +486,7 @@ steps:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'


@ -40,6 +40,82 @@ If vLLM successfully returns text (for generative models) or hidden states (for
Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.
### Transformers fallback
After the merge of <gh-pr:11330>, `vllm` can fall back to models that are available in `transformers`. This does not work for all models yet, but most decoder language models are supported, and vision-language model support is planned!
To check whether the backend being used is `transformers`, you can run this:
```python
from vllm import LLM
llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
```
If the printed class is `TransformersModel`, then the model is running on the `transformers` backend!
#### Supported features
##### LoRA and quantization
Neither is supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
Usually, `transformers` models load adapter weights via the `load_adapter` API, which depends on PEFT. We need to do some work to either use this API (for now, this would result in some weights not being marked as loaded) or replace the relevant modules accordingly.
A hint of what this could look like:
```python
class TransformersModel(nn.Module, SupportsLoRA):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        ...
        self.model.load_adapter(
            vllm_config.load_config.model_loader_extra_config[
                "qlora_adapter_name_or_path"])
```
The blocker is that you currently need to specify the supported LoRA layers, when ideally we would want to load whatever is inside the checkpoint!
##### Remote code
This fallback also means that any model on the Hub that can be used in `transformers` with `trust_remote_code=True` and that correctly implements attention can be used in production!
```python
from vllm import LLM
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
```
A model just needs the following two things:
```python
from transformers import PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from torch import nn


class MyAttention(nn.Module):

    def forward(self, hidden_states, **kwargs):  # <- kwargs are required
        ...
        attention_interface = ALL_ATTENTION_FUNCTIONS[
            self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            **kwargs,
        )
        ...


class MyModel(PreTrainedModel):
    _supports_attention_backend = True
```
Here is what happens in the background:

1. The config is loaded.
2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
3. The `TransformersModel` backend is used. See `vllm/model_executor/models/transformers.py`, which leverages `self.config._attn_implementation = "vllm"`, hence the need to use `ALL_ATTENTION_FUNCTIONS`.

That's it!
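
To make step 2 more concrete, here is a rough sketch of the kind of compatibility check that decides whether the fallback can be used. This is only an illustration, not the exact vLLM code path, and `can_use_transformers_backend` is a hypothetical helper:

```python
import transformers
from transformers.dynamic_module_utils import get_class_from_dynamic_module


def can_use_transformers_backend(config, model_name_or_path: str) -> bool:
    """Hypothetical helper: resolve the model class named in the config,
    then ask whether it declares support for vLLM's attention backend."""
    arch = config.architectures[0]
    auto_map = getattr(config, "auto_map", None)
    if auto_map is not None and "AutoModel" in auto_map:
        # Remote code: load the class referenced by the config's auto_map.
        model_cls = get_class_from_dynamic_module(auto_map["AutoModel"],
                                                  model_name_or_path)
    else:
        model_cls = getattr(transformers, arch, None)
    if model_cls is None:
        return False
    return getattr(model_cls, "_supports_attention_backend", False)
```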
### ModelScope
To use models from [ModelScope](https://www.modelscope.cn) instead of HuggingFace Hub, set an environment variable:


@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.48.2 # Required for Bamba.
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'


@ -281,12 +281,17 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="ibm-fms/llama-160m-accelerator"), # noqa: E501
}
_FALLBACK_MODEL = {
"TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
}
_EXAMPLE_MODELS = {
**_TEXT_GENERATION_EXAMPLE_MODELS,
**_EMBEDDING_EXAMPLE_MODELS,
**_CROSS_ENCODER_EXAMPLE_MODELS,
**_MULTIMODAL_EXAMPLE_MODELS,
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
**_FALLBACK_MODEL,
}


@ -15,7 +15,9 @@ def test_plugin(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = ""
with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy")
assert "are not supported for now" in str(excinfo.value)
error_msg = "has no vLLM implementation and " \
"the Transformers implementation is not compatible with vLLM."
assert (error_msg in str(excinfo.value))
@fork_new_process_for_each_test


@ -0,0 +1,75 @@
"""Test the functionality of the Transformers backend.
Run `pytest tests/models/test_transformers.py`.
"""
from contextlib import nullcontext
from typing import Type
import pytest
from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test
from .utils import check_logprobs_close
def check_implementation(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
example_prompts: list[str],
model: str,
**kwargs,
):
max_tokens = 32
num_logprobs = 5
with vllm_runner(model, **kwargs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
@pytest.mark.parametrize(
"model,model_impl",
[
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("openai-community/gpt2", "transformers"),
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("meta-llama/Llama-3.2-1B-Instruct", "auto"),
]) # trust_remote_code=True by default
def test_models(hf_runner, vllm_runner, example_prompts, model,
model_impl) -> None:
maybe_raises = nullcontext()
if model == "openai-community/gpt2" and model_impl == "transformers":
# Model is not backend compatible
maybe_raises = pytest.raises(
ValueError,
match="The Transformers implementation.*not compatible with vLLM")
with maybe_raises:
check_implementation(hf_runner,
vllm_runner,
example_prompts,
model,
model_impl=model_impl)
@multi_gpu_test(num_gpus=2)
def test_distributed(
hf_runner,
vllm_runner,
example_prompts,
):
kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
check_implementation(hf_runner, vllm_runner, example_prompts,
"meta-llama/Llama-3.2-1B-Instruct", **kwargs)


@ -83,6 +83,12 @@ class SupportsHash(Protocol):
...
class ModelImpl(str, enum.Enum):
AUTO = "auto"
VLLM = "vllm"
TRANSFORMERS = "transformers"
class ModelConfig:
"""Configuration for the model.
@ -167,6 +173,12 @@ class ModelConfig:
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
model_impl: Which implementation of the model to use:
"auto" will try to use the vLLM implementation if it exists and
fall back to the Transformers implementation if no vLLM
implementation is available.
"vllm" will use the vLLM model implementation.
"transformers" will use the Transformers model implementation.
override_generation_config: Override the generation config with the
given config.
"""
@ -230,6 +242,7 @@ class ModelConfig:
generation_config: Optional[str] = None,
enable_sleep_mode: bool = False,
override_generation_config: Optional[Dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = model
self.tokenizer = tokenizer
@ -241,6 +254,7 @@ class ModelConfig:
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.model_impl = model_impl
if hf_overrides is None:
hf_overrides = {}


@ -13,10 +13,10 @@ import vllm.envs as envs
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
ModelConfig, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@ -199,6 +199,7 @@ class EngineArgs:
generation_config: Optional[str] = None
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False
model_impl: str = "auto"
calculate_kv_scales: Optional[bool] = None
@ -378,6 +379,18 @@ class EngineArgs:
'qualified names that can be passed with the `logits_processors` '
'extra completion argument. Defaults to None, which allows no '
'processors.')
parser.add_argument(
'--model-impl',
type=str,
default=EngineArgs.model_impl,
choices=[f.value for f in ModelImpl],
help='Which implementation of the model to use.\n\n'
'* "auto" will try to use the vLLM implementation if it exists '
'and fall back to the Transformers implementation if no vLLM '
'implementation is available.\n'
'* "vllm" will use the vLLM model implementation.\n'
'* "transformers" will use the Transformers model '
'implementation.\n')
# Parallel arguments
parser.add_argument(
'--distributed-executor-backend',
@ -1017,6 +1030,7 @@ class EngineArgs:
generation_config=self.generation_config,
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
)
def create_load_config(self) -> LoadConfig:


@ -2,17 +2,22 @@
"""Utilities for selecting and loading models."""
import contextlib
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type
import torch
import transformers
from torch import nn
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from vllm.config import ModelConfig
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
logger = init_logger(__name__)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
@ -23,6 +28,50 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)
def is_transformers_impl_compatible(
        arch: str,
        module: Optional[transformers.PreTrainedModel] = None) -> bool:
    mod = module or getattr(transformers, arch, None)
    if mod is None:
        return False
    if hasattr(mod, "supports_backend"):
        return mod.is_backend_compatible()
    else:
        return mod._supports_flex_attn


def resolve_transformers_fallback(model_config: ModelConfig,
                                  architectures: list[str]):
    for i, arch in enumerate(architectures):
        if arch == "TransformersModel":
            continue
        custom_module = None
        auto_map = getattr(model_config.hf_config, "auto_map", None)
        if auto_map is not None and "AutoModel" in auto_map:
            custom_module = get_class_from_dynamic_module(
                model_config.hf_config.auto_map["AutoModel"],
                model_config.model)
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[i] = "TransformersModel"
        if model_config.model_impl == ModelImpl.AUTO:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[i] = "TransformersModel"
    return architectures
def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
@ -38,6 +87,14 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
vllm_supported_archs = ModelRegistry.get_supported_archs()
is_vllm_supported = any(arch in vllm_supported_archs
for arch in architectures)
if (not is_vllm_supported
or model_config.model_impl == ModelImpl.TRANSFORMERS):
architectures = resolve_transformers_fallback(model_config,
architectures)
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
model_cls = as_embedding_model(model_cls)


@ -184,6 +184,10 @@ _SPECULATIVE_DECODING_MODELS = {
"MedusaModel": ("medusa", "Medusa"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}
_FALLBACK_MODEL = {
"TransformersModel": ("transformers", "TransformersModel"),
}
# yapf: enable
_VLLM_MODELS = {
@ -192,6 +196,7 @@ _VLLM_MODELS = {
**_CROSS_ENCODER_MODELS,
**_MULTIMODAL_MODELS,
**_SPECULATIVE_DECODING_MODELS,
**_FALLBACK_MODEL,
}
@ -378,7 +383,12 @@ class _ModelRegistry:
if not architectures:
logger.warning("No model architectures are specified")
return architectures
normalized_arch = []
for model in architectures:
if model not in self.models:
model = "TransformersModel"
normalized_arch.append(model)
return normalized_arch
def inspect_model_cls(
self,


@ -0,0 +1,264 @@
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .utils import maybe_prefix
logger = init_logger(__name__)
def vllm_flash_attention_forward(
# Transformers args
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor,
# Transformers kwargs
scaling: float = None,
# vLLM kwargs
attn_metadata: AttentionMetadata = None,
attention_instances: list[Attention] = None,
**kwargs):
self_attn = attention_instances[module.layer_idx]
if scaling is not None:
self_attn.impl.scale = float(scaling)
hidden = query.shape[-2]
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
return self_attn.forward(
query,
key,
value,
kv_cache=None, # argument not used
attn_metadata=attn_metadata), None
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward
# Linear Layer that is compatible with transformers internal forward
# TODO: This is a temporary solution, we should find a better way to integrate
class HFColumnParallelLinear(ColumnParallelLinear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
class HFRowParallelLinear(RowParallelLinear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)[0]
def replace_tp_linear_class(orig_module: nn.Linear,
style: str,
quant_config=None):
"""
In model configurations, we use a neutral type (string) to specify parallel
styles, here we use it to translate nn.Linear into vllm-style tp Linear.
Quant config is not supported yet
"""
if not isinstance(style, str):
raise ValueError(
f"Unsupported parallel style type {type(style)}, expected str")
input_size = orig_module.in_features
output_size = orig_module.out_features
bias = orig_module.bias is not None
if style == "colwise":
return HFColumnParallelLinear(
input_size,
output_size,
bias,
)
elif style == "rowwise":
return HFRowParallelLinear(
input_size,
output_size,
bias,
)
# We don't consider colwise_rep since it's used in lm_head
else:
raise ValueError(f"Unsupported parallel style value: {style}")
class TransformersModel(nn.Module):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
logger.info("Using Transformers backend.")
self.vllm_config = vllm_config
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.config = config
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
attn_implementation="vllm",
torch_dtype=vllm_config.model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code,
)
prefix = self.model.base_model_prefix
# MLP modifications
self.tensor_parallelize(self.model)
# Attention modifications (assumes 1 attention op per hidden layer)
tp_size = get_tensor_model_parallel_world_size()
self.attention_instances = [
Attention(
num_heads=divide(config.num_attention_heads, tp_size),
head_size=config.head_dim,
# NOTE: We use Llama scale as default, if it's set by
# Transformers, it's updated in vllm_flash_attention_forward
scale=config.head_dim**-0.5,
num_kv_heads=divide(config.num_key_value_heads, tp_size),
cache_config=cache_config,
quant_config=None,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
]
# Model modifications
self.replace_vocab_embed_class(self.model)
# ForCausalLM modifications
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=None,
prefix=maybe_prefix(prefix, "lm_head"))
if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = get_sampler()
def log_replacement(self, name: str, old_module: nn.Module,
new_module: nn.Module):
logger.debug("%s: %s -> %s", name, old_module, new_module)
def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
if (self.config.base_model_tp_plan is None
and self.vllm_config.parallel_config.tensor_parallel_size > 1):
raise ValueError(
"Trying to run tensor parallelization but the model does not "
"support it yet!")
for child_name, child_module in module.named_children():
qual_name = prefix + child_name
for pattern, style in self.config.base_model_tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_tp_linear_class(
child_module, style, self.quant_config)
setattr(module, child_name, new_module)
self.log_replacement(qual_name, child_module, new_module)
else:
self.tensor_parallelize(child_module, prefix=f"{qual_name}.")
def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.config.vocab_size,
quant_config=None,
)
self.log_replacement("input embedding",
self.model.get_input_embeddings(), new_module)
self.model.set_input_embeddings(new_module)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor], # argument not used
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(
input_ids[None, ...],
use_cache=False,
position_ids=positions[None, ...],
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits
def sample(self, logits: torch.Tensor,
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params