Separate base model from TransformersModel (#15467)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

Parent: 4ec2cee000
Commit: cf5c8f1686
@@ -57,10 +57,10 @@ llm = LLM(model=..., task="generate") # Name or path of your model
 llm.apply_model(lambda model: print(type(model)))
 ```

-If it is `TransformersModel` then it means it's based on Transformers!
+If it is `TransformersForCausalLM` then it means it's based on Transformers!

 :::{tip}
-You can force the use of `TransformersModel` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
+You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for <project:#offline-inference> or `--model-impl transformers` for the <project:#openai-compatible-server>.
 :::

 :::{note}
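As a companion to the documentation change above, here is a minimal offline-inference sketch of that check; the model name is only a placeholder, and the printed class assumes the fallback is actually selected:

```python
from vllm import LLM

# Force the Transformers fallback rather than a native vLLM implementation.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")

# Inspect the class the engine instantiated; after this change it should be
# TransformersForCausalLM instead of the old TransformersModel.
llm.apply_model(lambda model: print(type(model)))
```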
@@ -119,7 +119,7 @@ Here is what happens in the background:

 1. The config is loaded
 2. `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
-3. The `TransformersModel` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverage `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTION`.
+3. The `TransformersForCausalLM` backend is used. See <gh-file:vllm/model_executor/models/transformers.py>, which leverages `self.config._attn_implementation = "vllm"`, thus the need to use `ALL_ATTENTION_FUNCTIONS`.

 To make your model compatible with tensor parallel, it needs:

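Step 3 relies on Transformers' pluggable attention registry. The sketch below is illustrative only (the backend name and function body are invented; the real hook lives in vllm/model_executor/models/transformers.py) and assumes a recent transformers release that exposes `ALL_ATTENTION_FUNCTIONS`:

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def my_backend_attention(module, query, key, value, attention_mask, **kwargs):
    # A real backend would call into its own attention kernels here;
    # this placeholder only marks where that dispatch would happen.
    raise NotImplementedError("illustrative placeholder")


# Registering under a custom key means that a model created with
# config._attn_implementation = "my-backend" routes every attention call
# through the function above.
ALL_ATTENTION_FUNCTIONS["my-backend"] = my_backend_attention
```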
@@ -175,7 +175,7 @@ TEXT_GENERATION_MODELS = {
     "inceptionai/jais-13b-chat": PPTestSettings.fast(),
     "ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
     "meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
-    # Tests TransformersModel
+    # Tests TransformersForCausalLM
     "ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
     "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
     "openbmb/MiniCPM3-4B": PPTestSettings.fast(),
@@ -319,7 +319,7 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
 }

 _FALLBACK_MODEL = {
-    "TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
+    "TransformersForCausalLM": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
 }

 _EXAMPLE_MODELS = {
@@ -45,7 +45,7 @@ def is_transformers_impl_compatible(
 def resolve_transformers_fallback(model_config: ModelConfig,
                                   architectures: list[str]):
     for i, arch in enumerate(architectures):
-        if arch == "TransformersModel":
+        if arch == "TransformersForCausalLM":
             continue
         auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                            None) or dict()
@@ -69,7 +69,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
                 raise ValueError(
                     f"The Transformers implementation of {arch} is not "
                     "compatible with vLLM.")
-            architectures[i] = "TransformersModel"
+            architectures[i] = "TransformersForCausalLM"
         if model_config.model_impl == ModelImpl.AUTO:
             if not is_transformers_impl_compatible(arch, custom_model_module):
                 raise ValueError(
@@ -80,7 +80,7 @@ def resolve_transformers_fallback(model_config: ModelConfig,
                 "%s has no vLLM implementation, falling back to Transformers "
                 "implementation. Some features may not be supported and "
                 "performance may not be optimal.", arch)
-            architectures[i] = "TransformersModel"
+            architectures[i] = "TransformersForCausalLM"
     return architectures


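For clarity, the rewrite performed by `resolve_transformers_fallback` can be summarised by the self-contained sketch below, with the compatibility check stubbed out (nothing here is the real vLLM code):

```python
def rewrite_to_fallback(architectures: list[str]) -> list[str]:
    """Toy version of the fallback resolution shown in the diff above."""
    for i, arch in enumerate(architectures):
        if arch == "TransformersForCausalLM":
            continue  # already the fallback entry
        # The real code inspects auto_map and _supports_attention_backend here.
        compatible = True  # stand-in for is_transformers_impl_compatible(...)
        if not compatible:
            raise ValueError(
                f"The Transformers implementation of {arch} is not "
                "compatible with vLLM.")
        architectures[i] = "TransformersForCausalLM"
    return architectures


print(rewrite_to_fallback(["MyCustomForCausalLM"]))
# ['TransformersForCausalLM']
```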
@@ -201,7 +201,7 @@ _SPECULATIVE_DECODING_MODELS = {
 }

 _FALLBACK_MODEL = {
-    "TransformersModel": ("transformers", "TransformersModel"),
+    "TransformersForCausalLM": ("transformers", "TransformersForCausalLM"),
 }
 # yapf: enable

@@ -425,7 +425,7 @@ class _ModelRegistry:

         # make sure Transformers fallback are put at the last
         if len(normalized_arch) != len(architectures):
-            normalized_arch.append("TransformersModel")
+            normalized_arch.append("TransformersForCausalLM")
         return normalized_arch

     def inspect_model_cls(
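To illustrate the registry change above: when some declared architectures are unknown to vLLM, the fallback name is appended once, after all native candidates. The architecture names and registry check below are invented for the example:

```python
declared = ["MyCustomForCausalLM", "LlamaForCausalLM"]  # hypothetical input
known_to_vllm = {"LlamaForCausalLM"}                    # hypothetical registry

normalized = [arch for arch in declared if arch in known_to_vllm]
if len(normalized) != len(declared):
    # The fallback is appended last, so native implementations win if present.
    normalized.append("TransformersForCausalLM")

print(normalized)  # ['LlamaForCausalLM', 'TransformersForCausalLM']
```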
@@ -43,7 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
-from .utils import (PPMissingLayer, is_pp_missing_parameter,
+from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
+                    is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, maybe_prefix)

 logger = init_logger(__name__)
@@ -110,13 +111,9 @@ def replace_linear_class(
     )


-@support_torch_compile
-class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
-    embedding_padding_modules = ["lm_head"]
-    embedding_modules = ["embed_tokens"
-                         ]  # TODO transformers will have a util to get it
+class TransformersModel(nn.Module):

-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         logger.info("Using Transformers backend.")

@@ -134,9 +131,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         self.parallel_config = parallel_config
         self.quant_config = quant_config

-        self.vocab_size = model_config.get_vocab_size()
-        self.unpadded_vocab_size = model_config.get_vocab_size()
-
         self.pp_group = get_pp_group()
         self.pp_size = self.pp_group.world_size
         self.pp_rank = self.pp_group.rank_in_group
@@ -144,13 +138,15 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):

         # Use meta device to delay allocating GPU tensors
         with torch.device("meta"):
+            # FIXME(Isotr0py): We need to refactor this part in the future to
+            # avoid registering an extra model layer, otherwise we will need a
+            # weights mapper to rename weights.
             self.model: PreTrainedModel = AutoModel.from_config(
                 config,
                 attn_implementation="vllm",
                 torch_dtype=model_config.dtype,
                 trust_remote_code=model_config.trust_remote_code,
             )
-        prefix = self.model.base_model_prefix

         self.pipeline_parallel()
         self.tensor_parallel()
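The meta-device trick in this hunk can be reproduced outside vLLM. A standalone sketch, assuming a small public checkpoint and whatever device is available (both are placeholders, not part of this commit):

```python
import torch
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("gpt2")  # placeholder checkpoint

# Build the module tree on the meta device: shapes and dtypes exist,
# but no real memory is allocated yet.
with torch.device("meta"):
    model = AutoModel.from_config(config, torch_dtype=torch.float16)

# Materialise uninitialised storage on the target device only when needed;
# a weight loader would then copy the real parameters in.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to_empty(device=device)
```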
@@ -168,32 +164,12 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         # Attention layers
         self.attention_instances = self.create_attention_instances()

-        # Output embeddings
-        if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
-            self.unpadded_vocab_size = config.vocab_size
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=maybe_prefix(prefix, "lm_head"),
-            )
-            if config.tie_word_embeddings:
-                self.lm_head = self.lm_head.tie_weights(
-                    self.model.get_input_embeddings())
-
-            logit_scale = getattr(config, "logit_scale", 1.0)
-            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                    config.vocab_size,
-                                                    logit_scale)
-
         # Initialize buffers (e.g. rotary embedding inverse frequency)
         self.init_buffers(self.model)

         # Move remaining meta tensors to device (should happen last)
         self.meta_to_empty(self.model)

-        self.sampler = get_sampler()
-
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(["hidden_states"],
                                                     config.hidden_size))
@@ -248,9 +224,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
             if not self.pp_group.is_last_rank:
                 setattr(self.model, name, PPMissingLayer())

-        if not self.pp_group.is_last_rank:
-            self.lm_head = PPMissingLayer()
-
     def tensor_parallel(self):
         """
         Apply the model's tensor parallelization plan.
@@ -331,6 +304,9 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
         for child in module.children():
             self.meta_to_empty(child)

+    def get_input_embeddings(self) -> nn.Module:
+        return self.model.get_input_embeddings()
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -361,21 +337,6 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):

         return hidden_states

-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def sample(self, logits: torch.Tensor,
-               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
-
-        next_tokens = self.sampler(logits, sampling_metadata)
-        return next_tokens
-
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         params_dict = dict(self.named_parameters())
@@ -393,3 +354,93 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
         return loaded_params
+
+
+@support_torch_compile
+class TransformersForCausalLM(nn.Module, SupportsQuant, SupportsLoRA,
+                              SupportsPP):
+    embedding_padding_modules = ["lm_head"]
+    embedding_modules = ["embed_tokens"
+                         ]  # TODO transformers will have a util to get it
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config: PretrainedConfig = vllm_config.model_config.hf_config
+        quant_config: QuantizationConfig = vllm_config.quant_config
+
+        self.config = config
+
+        self.model = TransformersModel(vllm_config=vllm_config, prefix=prefix)
+
+        if get_pp_group().is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "lm_head"),
+            )
+            if config.tie_word_embeddings:
+                self.lm_head = self.lm_head.tie_weights(
+                    self.model.get_input_embeddings())
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                    config.vocab_size,
+                                                    logit_scale)
+        else:
+            self.lm_head = PPMissingLayer()
+
+        self.sampler = get_sampler()
+
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    # FIXME(Isotr0py): Don't use any weights mapper for Transformers fallback,
+    # this makes thing complicated. We need to remove this mapper after refactor
+    # `TransformersModel` in the future.
+    @property
+    def hf_to_vllm_mapper(self):
+        prefix_mapper = {
+            name: "model." + name
+            for name, _ in self.model.model.named_children()
+        }
+        return WeightsMapper(
+            orig_to_new_substr={"model.": "model.model."},
+            orig_to_new_prefix=prefix_mapper,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(self, logits: torch.Tensor,
+               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:
+
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
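Because the wrapper stores the backbone as `self.model` (itself a `TransformersModel` holding the Hugging Face module as `self.model.model`), `hf_to_vllm_mapper` has to push checkpoint keys one level deeper. A rough illustration of the substring part of that mapping, reimplemented with plain string handling and invented key names:

```python
# Hypothetical checkpoint keys for a Llama-style causal LM.
checkpoint_keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "lm_head.weight",
]

def remap(key: str) -> str:
    # Mirrors orig_to_new_substr={"model.": "model.model."}: backbone weights
    # gain an extra "model." prefix, while lm_head stays on the wrapper.
    if key.startswith("model."):
        return "model." + key
    return key

print([remap(k) for k in checkpoint_keys])
# ['model.model.embed_tokens.weight',
#  'model.model.layers.0.self_attn.q_proj.weight',
#  'lm_head.weight']
```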