Separate base model from TransformersModel (#15467)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Authored by Harry Mellor on 2025-03-26 10:13:38 +00:00; committed by GitHub.
parent 4ec2cee000
commit cf5c8f1686
6 changed files with 110 additions and 59 deletions


@@ -57,10 +57,10 @@ llm = LLM(model=..., task="generate")  # Name or path of your model
 llm.apply_model(lambda model: print(type(model)))
 ```

-If it is `TransformersModel` then it means it's based on Transformers!
+If it is `TransformersForCausalLM` then it means it's based on Transformers!

 :::{tip}
-You can force the use of `TransformersModel` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
+You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
 :::

 :::{note}
@@ -119,7 +119,7 @@ Here is what happens in the background:
 1. The config is loaded
 2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
-3. The `TransformersModel` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
+3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.

 To make your model compatible with tensor parallel, it needs:
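
For readers skimming the docs change above, here is a minimal offline-inference sketch of the renamed check. The model name is only a placeholder; any Transformers-compatible checkpoint works.

```python
from vllm import LLM

# Placeholder model; swap in whichever checkpoint you are testing.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")

# With model_impl="transformers" forced, the printed class should now be
# TransformersForCausalLM rather than the old TransformersModel.
llm.apply_model(lambda model: print(type(model)))
```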


@@ -175,7 +175,7 @@ TEXT_GENERATION_MODELS = {
     "inceptionai/jais-13b-chat": PPTestSettings.fast(),
     "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
     "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
-    # Tests TransformersModel
+    # Tests TransformersForCausalLM
     "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
     "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
     "openbmb/MiniCPM3-4B": PPTestSettings.fast(),


@@ -319,7 +319,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
 }

 _FALLBACK_MODEL = {
-    "TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
+    "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
 }

 _EXAMPLE_MODELS = {


@@ -45,7 +45,7 @@ def is_transformers_impl_compatible(
 def resolve_transformers_fallback(model_config: ModelConfig,
                                   architectures: list[str]):
     for i, arch in enumerate(architectures):
-        if arch == "TransformersModel":
+        if arch == "TransformersForCausalLM":
             continue
         auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                            None) or dict()
@@ -69,7 +69,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
                 raise ValueError(
                     f"The Transformers implementation of {arch} is not "
                     "compatible with vLLM.")
-            architectures[i] = "TransformersModel"
+            architectures[i] = "TransformersForCausalLM"
         if model_config.model_impl == ModelImpl.AUTO:
             if not is_transformers_impl_compatible(arch, custom_model_module):
                 raise ValueError(
@@ -80,7 +80,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
                 "%s has no vLLM implementation, falling back to Transformers "
                 "implementation. Some features may not be supported and "
                 "performance may not be optimal.", arch)
-            architectures[i] = "TransformersModel"
+            architectures[i] = "TransformersForCausalLM"
     return architectures
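
As a rough, standalone sketch of what the renamed fallback resolution above does (not vLLM's actual code path, which also distinguishes `ModelImpl.TRANSFORMERS` from `ModelImpl.AUTO` and inspects `auto_map`), every architecture that passes the compatibility check is rewritten to the new fallback name:

```python
# Simplified illustration; `is_compatible` stands in for vLLM's
# is_transformers_impl_compatible check.
def resolve_fallback(architectures: list[str], is_compatible) -> list[str]:
    for i, arch in enumerate(architectures):
        if arch == "TransformersForCausalLM":
            continue  # already the fallback architecture
        if not is_compatible(arch):
            raise ValueError(f"The Transformers implementation of {arch} "
                             "is not compatible with vLLM.")
        architectures[i] = "TransformersForCausalLM"
    return architectures


print(resolve_fallback(["MyCustomForCausalLM"], lambda arch: True))
# ['TransformersForCausalLM']
```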


@@ -201,7 +201,7 @@ _SPECULATIVE_DECODING_MODELS = {
 }

 _FALLBACK_MODEL = {
-    "TransformersModel": ("transformers", "TransformersModel"),
+    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
 }

 # yapf: enable
@@ -425,7 +425,7 @@ class _ModelRegistry:
         # make sure Transformers fallback are put at the last
         if len(normalized_arch) != len(architectures):
-            normalized_arch.append("TransformersModel")
+            normalized_arch.append("TransformersForCausalLM")
         return normalized_arch

     def inspect_model_cls(
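
The registry change keeps the same ordering guarantee under the new name. A toy sketch of the intent, using a hypothetical `KNOWN_ARCHS` set rather than the real registry:

```python
KNOWN_ARCHS = {"LlamaForCausalLM"}  # hypothetical stand-in for the registry


def normalize(architectures: list[str]) -> list[str]:
    # Architectures without a native vLLM implementation are filtered out...
    normalized = [arch for arch in architectures if arch in KNOWN_ARCHS]
    # ...and the Transformers fallback is appended last, so it is only tried
    # when no native implementation matches.
    if len(normalized) != len(architectures):
        normalized.append("TransformersForCausalLM")
    return normalized


print(normalize(["LlamaForCausalLM", "MyCustomArch"]))
# ['LlamaForCausalLM', 'TransformersForCausalLM']
```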


@@ -43,7 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, maybe_prefix)

 logger = init_logger(__name__)
@@ -110,13 +111,9 @@ def replace_linear_class(
     )


-@support_torch_compile
-class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersModel(nn.Module):

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")
@@ -134,9 +131,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.parallel_config = parallel_config
         self.quant_config = quant_config

-        self.vocab_size = model_config.get_vocab_size()
-        self.unpadded_vocab_size = model_config.get_vocab_size()
-
         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
@@ -144,13 +138,15 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         # Use meta device to delay allocating GPU tensors
         with torch.device("meta"):
+            # FIXME(Isotr0py): We need to refactor this part in the future to
+            # avoid registering an extra model layer, otherwise we will need a
+            # weights mapper to rename weights.
             self.model: PreTrainedModel = AutoModel.from_config(
                 config,
                 attn_implementation="vllm",
                 torch_dtype=model_config.dtype,
                 trust_remote_code=model_config.trust_remote_code,
             )
+        prefix = self.model.base_model_prefix

         self.pipeline_parallel()
         self.tensor_parallel()
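
Outside vLLM, the meta-device construction and `base_model_prefix` reuse can be sketched as below. The checkpoint name is an assumed example; the real code also passes `attn_implementation="vllm"`, the target dtype, and `trust_remote_code`.

```python
import torch
from transformers import AutoConfig, AutoModel

# Assumed example checkpoint, used only to illustrate the pattern.
config = AutoConfig.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")

# Building under the meta device allocates no real storage; weights are
# materialized later (vLLM fills them in during load_weights).
with torch.device("meta"):
    backbone = AutoModel.from_config(config)

# base_model_prefix ("model" for Llama-style models) is what the diff above
# now reuses as the weight prefix instead of a hard-coded string.
print(backbone.base_model_prefix)
print(next(backbone.parameters()).device)  # device(type='meta')
```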
@@ -168,32 +164,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         # Attention layers
         self.attention_instances = self.create_attention_instances()

-        # Output embeddings
-        if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
-            self.unpadded_vocab_size = config.vocab_size
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)

         # Move remaining meta tensors to device (should happen last)
         self.meta_to_empty(self.model)

-        self.sampler = get_sampler()
-
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
@@ -248,9 +224,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             if not self.pp_group.is_last_rank:
                 setattr(self.model, name, PPMissingLayer())

-        if not self.pp_group.is_last_rank:
-            self.lm_head = PPMissingLayer()
-
     def tensor_parallel(self):
         """
         Apply the model's tensor parallelization plan.
@@ -331,6 +304,9 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         for child in module.children():
             self.meta_to_empty(child)

+    def get_input_embeddings(self) -> nn.Module:
+        return self.model.get_input_embeddings()
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -361,21 +337,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):

         return hidden_states

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(self, logits: torch.Tensor,
-               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
@@ -393,3 +354,93 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
+
+
+@support_torch_compile
+class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
+                              SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config: PretrainedConfig = vllm_config.model_config.hf_config
+        quant_config: QuantizationConfig = vllm_config.quant_config
+
+        self.config = config
+
+        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+            if config.tie_word_embeddings:
+                self.lm_head = self.lm_head.tie_weights(
+                    self.model.get_input_embeddings())
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    # FIXME(Isotr0py): Don't use any weights mapper for Transformers fallback,
+    # this makes thing complicated. We need to remove this mapper after refactor
+    # `TransformersModel` in the future.
+    @property
+    def hf_to_vllm_mapper(self):
+        prefix_mapper = {
+            name: "model." + name
+            for name, _ in self.model.model.named_children()
+        }
+        return WeightsMapper(
+            orig_to_new_substr={"model.": "model.model."},
+            orig_to_new_prefix=prefix_mapper,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(self, logits: torch.Tensor,
+               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
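
To make the `hf_to_vllm_mapper` FIXME concrete, here is a standalone illustration (not vLLM's `WeightsMapper`) of the renaming the mapper exists to perform: checkpoints are written for the flat `<Arch>ForCausalLM` layout, but the wrapper now nests the backbone one module deeper (`TransformersForCausalLM.model` → `TransformersModel.model` → HF `AutoModel`).

```python
# Illustrative only: map HF checkpoint parameter names onto the doubly nested
# module tree created by the new wrapper class.
def rename_checkpoint_key(key: str) -> str:
    if key.startswith("lm_head."):
        return key  # lm_head lives on the wrapper itself, not on the backbone
    if key.startswith("model."):
        return "model.model." + key[len("model."):]
    # Bare backbone checkpoints (no "model." prefix) end up under the same
    # doubly nested path.
    return "model.model." + key


for key in ("model.embed_tokens.weight",
            "model.layers.0.self_attn.q_proj.weight",
            "lm_head.weight"):
    print(f"{key} -> {rename_checkpoint_key(key)}")
```

When `tie_word_embeddings` is set, `load_weights` above additionally skips `lm_head.` keys entirely, since the head is tied to the input embeddings rather than loaded from the checkpoint.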