Separate base model from TransformersModel
(#15467)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
parent
4ec2cee000
commit
cf5c8f1686
@ -57,10 +57,10 @@ llm = LLM(model=..., task="generate") # Name or path of your model
|
||||
llm.apply_model(lambda model: print(type(model)))
|
||||
```
|
||||
|
||||
If it is `TransformersModel` then it means it's based on Transformers!
|
||||
If it is `TransformersForCausalLM` then it means it's based on Transformers!
|
||||
|
||||
:::{tip}
|
||||
You can force the use of `TransformersModel` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
|
||||
You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
|
||||
:::
|
||||
|
||||
:::{note}
|
||||
@ -119,7 +119,7 @@ Here is what happens in the background:
|
||||
|
||||
1. The config is loaded
|
||||
2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
|
||||
3. The `TransformersModel` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
|
||||
3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
|
||||
|
||||
To make your model compatible with tensor parallel, it needs:
|
||||
|
||||
|
@ -175,7 +175,7 @@ TEXT_GENERATION_MODELS = {
|
||||
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
|
||||
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
|
||||
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
|
||||
# Tests TransformersModel
|
||||
# Tests TransformersForCausalLM
|
||||
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
|
||||
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
|
||||
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
|
||||
|
@ -319,7 +319,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
}
|
||||
|
||||
_FALLBACK_MODEL = {
|
||||
"TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
||||
"TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
|
||||
}
|
||||
|
||||
_EXAMPLE_MODELS = {
|
||||
|
@ -45,7 +45,7 @@ def is_transformers_impl_compatible(
|
||||
def resolve_transformers_fallback(model_config: ModelConfig,
|
||||
architectures: list[str]):
|
||||
for i, arch in enumerate(architectures):
|
||||
if arch == "TransformersModel":
|
||||
if arch == "TransformersForCausalLM":
|
||||
continue
|
||||
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
|
||||
None) or dict()
|
||||
@ -69,7 +69,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
|
||||
raise ValueError(
|
||||
f"The Transformers implementation of {arch} is not "
|
||||
"compatible with vLLM.")
|
||||
architectures[i] = "TransformersModel"
|
||||
architectures[i] = "TransformersForCausalLM"
|
||||
if model_config.model_impl == ModelImpl.AUTO:
|
||||
if not is_transformers_impl_compatible(arch, custom_model_module):
|
||||
raise ValueError(
|
||||
@ -80,7 +80,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
|
||||
"%s has no vLLM implementation, falling back to Transformers "
|
||||
"implementation. Some features may not be supported and "
|
||||
"performance may not be optimal.", arch)
|
||||
architectures[i] = "TransformersModel"
|
||||
architectures[i] = "TransformersForCausalLM"
|
||||
return architectures
|
||||
|
||||
|
||||
|
@ -201,7 +201,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
}
|
||||
|
||||
_FALLBACK_MODEL = {
|
||||
"TransformersModel": ("transformers", "TransformersModel"),
|
||||
"TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
@ -425,7 +425,7 @@ class _ModelRegistry:
|
||||
|
||||
# make sure Transformers fallback are put at the last
|
||||
if len(normalized_arch) != len(architectures):
|
||||
normalized_arch.append("TransformersModel")
|
||||
normalized_arch.append("TransformersForCausalLM")
|
||||
return normalized_arch
|
||||
|
||||
def inspect_model_cls(
|
||||
|
@ -43,7 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -110,13 +111,9 @@ def replace_linear_class(
|
||||
)
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
embedding_modules = ["embed_tokens"
|
||||
] # TODO transformers will have a util to get it
|
||||
class TransformersModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
logger.info("Using Transformers backend.")
|
||||
|
||||
@ -134,9 +131,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
self.parallel_config = parallel_config
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.vocab_size = model_config.get_vocab_size()
|
||||
self.unpadded_vocab_size = model_config.get_vocab_size()
|
||||
|
||||
self.pp_group = get_pp_group()
|
||||
self.pp_size = self.pp_group.world_size
|
||||
self.pp_rank = self.pp_group.rank_in_group
|
||||
@ -144,13 +138,15 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
|
||||
# Use meta device to delay allocating GPU tensors
|
||||
with torch.device("meta"):
|
||||
# FIXME(Isotr0py): We need to refactor this part in the future to
|
||||
# avoid registering an extra model layer, otherwise we will need a
|
||||
# weights mapper to rename weights.
|
||||
self.model: PreTrainedModel = AutoModel.from_config(
|
||||
config,
|
||||
attn_implementation="vllm",
|
||||
torch_dtype=model_config.dtype,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
)
|
||||
prefix = self.model.base_model_prefix
|
||||
|
||||
self.pipeline_parallel()
|
||||
self.tensor_parallel()
|
||||
@ -168,32 +164,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
# Attention layers
|
||||
self.attention_instances = self.create_attention_instances()
|
||||
|
||||
# Output embeddings
|
||||
if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(
|
||||
self.model.get_input_embeddings())
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logit_scale)
|
||||
|
||||
# Initialize buffers (e.g. rotary embedding inverse frequency)
|
||||
self.init_buffers(self.model)
|
||||
|
||||
# Move remaining meta tensors to device (should happen last)
|
||||
self.meta_to_empty(self.model)
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||
config.hidden_size))
|
||||
@ -248,9 +224,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
if not self.pp_group.is_last_rank:
|
||||
setattr(self.model, name, PPMissingLayer())
|
||||
|
||||
if not self.pp_group.is_last_rank:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
def tensor_parallel(self):
|
||||
"""
|
||||
Apply the model's tensor parallelization plan.
|
||||
@ -331,6 +304,9 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
for child in module.children():
|
||||
self.meta_to_empty(child)
|
||||
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.model.get_input_embeddings()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
@ -361,21 +337,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
@ -393,3 +354,93 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
|
||||
SupportsPP):
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
embedding_modules = ["embed_tokens"
|
||||
] # TODO transformers will have a util to get it
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: PretrainedConfig = vllm_config.model_config.hf_config
|
||||
quant_config: QuantizationConfig = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(
|
||||
self.model.get_input_embeddings())
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||
config.vocab_size,
|
||||
logit_scale)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.sampler = get_sampler()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
# FIXME(Isotr0py): Don't use any weights mapper for Transformers fallback,
|
||||
# this makes thing complicated. We need to remove this mapper after refactor
|
||||
# `TransformersModel` in the future.
|
||||
@property
|
||||
def hf_to_vllm_mapper(self):
|
||||
prefix_mapper = {
|
||||
name: "model." + name
|
||||
for name, _ in self.model.model.named_children()
|
||||
}
|
||||
return WeightsMapper(
|
||||
orig_to_new_substr={"model.": "model.model."},
|
||||
orig_to_new_prefix=prefix_mapper,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
model_output = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
|
||||
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
Loading…
x
Reference in New Issue
Block a user