[Model]: Add transformers backend support (#11330)

# Adds support for `transformers` as a backend

Following https://github.com/huggingface/transformers/pull/35235, a bunch of models should already be supported, and we are ramping up support for more models. Thanks @Isotr0py for the TP support, and @hmellor for his help as well!

This includes:
- `trust_remote_code=True` support: any model on the Hub that implements attention the correct way can be natively supported!
- tensor parallel support

---------

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <41363108+Isotr0py@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>

This commit is contained in:
parent 1298a400e8
commit a1a2aaadb9
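As a quick illustration (not part of the diff below), here is a minimal sketch of how the new backend can be selected through the `model_impl` option introduced in this PR; the model name is just an example taken from the tests:

```python
from vllm import LLM, SamplingParams

# Force the Transformers backend via the `model_impl` flag added in this PR.
# "auto" (the default) falls back to Transformers only when no native vLLM
# implementation exists for the architecture.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```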
@@ -349,6 +349,7 @@ steps:
- vllm/
- tests/models
commands:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py

@@ -485,6 +486,7 @@ steps:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
@@ -40,6 +40,82 @@ If vLLM successfully returns text (for generative models) or hidden states (for
Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.

### Transformers fallback

After the merge of <gh-pr:11330>, `vllm` can fall back to the Transformers implementation for models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!

To check whether the backend being used is `transformers`, you can simply do this:

```python
from vllm import LLM
llm = LLM(model=..., task="generate")  # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
```

If it is `TransformersModel`, then the model is running on the `transformers` backend!
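You can also force the Transformers implementation even when a native vLLM implementation exists, via the `model_impl` option added in <gh-pr:11330>. A minimal sketch (the model name is only an example):

```python
from vllm import LLM

# Explicitly request the Transformers backend instead of letting vLLM pick.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          task="generate",
          model_impl="transformers")
llm.apply_model(lambda model: print(model.__class__))  # expect TransformersModel
```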
#### Supported features

##### LoRA and quantization

Both are not supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!

Usually `transformers` models load adapter weights via the `load_adapter` API, which depends on PEFT. We need to work a bit to either use this API (for now this would result in some weights not being marked as loaded) or replace the modules accordingly.

Here is a hint of what this could look like:
```python
class TransformersModel(nn.Module, SupportsLoRA):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        ...
        self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"])
```

The blocker is that you need to specify the supported LoRA layers, when ideally we would want to load whatever is inside the checkpoint!
##### Remote code

This fallback also means that any model on the Hub that can be used in `transformers` with `trust_remote_code=True`, and that correctly implements attention, can be used in production!

```python
from vllm import LLM
llm = LLM(model=..., task="generate", trust_remote_code=True)  # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
```
A model just needs the following two things:

```python
from transformers import PreTrainedModel
from torch import nn


class MyAttention(nn.Module):

    def forward(self, hidden_states, **kwargs):  # <- kwargs are required
        ...
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            **kwargs,
        )
        ...


class MyModel(PreTrainedModel):
    _supports_attention_backend = True
```
Here is what happens in the background:

1. The config is loaded.
2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
3. The `TransformersModel` backend is used. See `vllm/model_executor/models/transformers.py`, which leverages `self.config._attn_implementation = "vllm"`, hence the need to use `ALL_ATTENTION_FUNCTIONS` (see the sketch below).
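For reference, a minimal sketch of that registration mechanism, modeled on the `vllm_flash_attention_forward` hook added in this PR (the body is a placeholder; the real hook also receives vLLM's attention metadata and per-layer `Attention` instances through `**kwargs`):

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def vllm_attention_forward(module, query, key, value, attention_mask, **kwargs):
    # Placeholder body: the real hook hands query/key/value to the vLLM
    # `Attention` instance for `module.layer_idx`, received via **kwargs.
    ...


# Registering under the "vllm" key means any attention module that looks up
# ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] now dispatches to
# vLLM, because TransformersModel builds the model with
# attn_implementation="vllm".
ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_attention_forward
```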
That's it!

### ModelScope

To use models from [ModelScope](https://www.modelscope.cn) instead of HuggingFace Hub, set an environment variable:
@@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.48.2 # Required for Bamba.
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
@@ -281,12 +281,17 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
                         speculative_model="ibm-fms/llama-160m-accelerator"),  # noqa: E501
}

_FALLBACK_MODEL = {
    "TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
}

_EXAMPLE_MODELS = {
    **_TEXT_GENERATION_EXAMPLE_MODELS,
    **_EMBEDDING_EXAMPLE_MODELS,
    **_CROSS_ENCODER_EXAMPLE_MODELS,
    **_MULTIMODAL_EXAMPLE_MODELS,
    **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
    **_FALLBACK_MODEL,
}
@@ -15,7 +15,9 @@ def test_plugin(dummy_opt_path):
    os.environ["VLLM_PLUGINS"] = ""
    with pytest.raises(Exception) as excinfo:
        LLM(model=dummy_opt_path, load_format="dummy")
    assert "are not supported for now" in str(excinfo.value)
    error_msg = "has no vLLM implementation and " \
        "the Transformers implementation is not compatible with vLLM."
    assert (error_msg in str(excinfo.value))


@fork_new_process_for_each_test
tests/models/test_transformers.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Test the functionality of the Transformers backend.

Run `pytest tests/models/test_transformers.py`.
"""
from contextlib import nullcontext
from typing import Type

import pytest

from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test
from .utils import check_logprobs_close


def check_implementation(
    hf_runner: Type[HfRunner],
    vllm_runner: Type[VllmRunner],
    example_prompts: list[str],
    model: str,
    **kwargs,
):
    max_tokens = 32
    num_logprobs = 5

    with vllm_runner(model, **kwargs) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize(
    "model,model_impl",
    [
        ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
        ("openai-community/gpt2", "transformers"),
        ("ArthurZ/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
        ("meta-llama/Llama-3.2-1B-Instruct", "auto"),
    ])  # trust_remote_code=True by default
def test_models(hf_runner, vllm_runner, example_prompts, model,
                model_impl) -> None:

    maybe_raises = nullcontext()
    if model == "openai-community/gpt2" and model_impl == "transformers":
        # Model is not backend compatible
        maybe_raises = pytest.raises(
            ValueError,
            match="The Transformers implementation.*not compatible with vLLM")

    with maybe_raises:
        check_implementation(hf_runner,
                             vllm_runner,
                             example_prompts,
                             model,
                             model_impl=model_impl)


@multi_gpu_test(num_gpus=2)
def test_distributed(
    hf_runner,
    vllm_runner,
    example_prompts,
):
    kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
    check_implementation(hf_runner, vllm_runner, example_prompts,
                         "meta-llama/Llama-3.2-1B-Instruct", **kwargs)
@@ -83,6 +83,12 @@ class SupportsHash(Protocol):
        ...


class ModelImpl(str, enum.Enum):
    AUTO = "auto"
    VLLM = "vllm"
    TRANSFORMERS = "transformers"


class ModelConfig:
    """Configuration for the model.

@@ -167,6 +173,12 @@ class ModelConfig:
            `logits_processors` extra completion argument. Defaults to None,
            which allows no processors.
        generation_config: Configuration parameter file for generation.
        model_impl: Which implementation of the model to use:
            "auto" will try to use the vLLM implementation if it exists and
            fall back to the Transformers implementation if no vLLM
            implementation is available.
            "vllm" will use the vLLM model implementation.
            "transformers" will use the Transformers model implementation.
        override_generation_config: Override the generation config with the
            given config.
    """
@@ -230,6 +242,7 @@ class ModelConfig:
        generation_config: Optional[str] = None,
        enable_sleep_mode: bool = False,
        override_generation_config: Optional[Dict[str, Any]] = None,
        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
@@ -241,6 +254,7 @@ class ModelConfig:
        self.code_revision = code_revision
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        self.model_impl = model_impl

        if hf_overrides is None:
            hf_overrides = {}
@@ -13,10 +13,10 @@ import vllm.envs as envs
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
                         DecodingConfig, DeviceConfig, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PoolerConfig, PromptAdapterConfig, SchedulerConfig,
                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
                         VllmConfig)
                         ModelConfig, ModelImpl, ObservabilityConfig,
                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
                         SchedulerConfig, SpeculativeConfig, TaskOption,
                         TokenizerPoolConfig, VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -199,6 +199,7 @@ class EngineArgs:
    generation_config: Optional[str] = None
    override_generation_config: Optional[Dict[str, Any]] = None
    enable_sleep_mode: bool = False
    model_impl: str = "auto"

    calculate_kv_scales: Optional[bool] = None
@@ -378,6 +379,18 @@ class EngineArgs:
            'qualified names that can be passed with the `logits_processors` '
            'extra completion argument. Defaults to None, which allows no '
            'processors.')
        parser.add_argument(
            '--model-impl',
            type=str,
            default=EngineArgs.model_impl,
            choices=[f.value for f in ModelImpl],
            help='Which implementation of the model to use.\n\n'
            '* "auto" will try to use the vLLM implementation if it exists '
            'and fall back to the Transformers implementation if no vLLM '
            'implementation is available.\n'
            '* "vllm" will use the vLLM model implementation.\n'
            '* "transformers" will use the Transformers model '
            'implementation.\n')
        # Parallel arguments
        parser.add_argument(
            '--distributed-executor-backend',
@@ -1017,6 +1030,7 @@ class EngineArgs:
            generation_config=self.generation_config,
            override_generation_config=self.override_generation_config,
            enable_sleep_mode=self.enable_sleep_mode,
            model_impl=self.model_impl,
        )

    def create_load_config(self) -> LoadConfig:
@@ -2,17 +2,22 @@
"""Utilities for selecting and loading models."""
import contextlib
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type

import torch
import transformers
from torch import nn
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from vllm.config import ModelConfig
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
                                                 as_embedding_model,
                                                 as_reward_model)

logger = init_logger(__name__)


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
@@ -23,6 +28,50 @@ def set_default_torch_dtype(dtype: torch.dtype):
    torch.set_default_dtype(old_dtype)


def is_transformers_impl_compatible(
        arch: str,
        module: Optional[transformers.PreTrainedModel] = None) -> bool:
    mod = module or getattr(transformers, arch, None)
    if mod is None:
        return False
    if hasattr(mod, "supports_backend"):
        return mod.is_backend_compatible()
    else:
        return mod._supports_flex_attn


def resolve_transformers_fallback(model_config: ModelConfig,
                                  architectures: list[str]):
    for i, arch in enumerate(architectures):
        if arch == "TransformersModel":
            continue
        custom_module = None
        auto_map = getattr(model_config.hf_config, "auto_map", None)
        if auto_map is not None and "AutoModel" in auto_map:
            custom_module = get_class_from_dynamic_module(
                model_config.hf_config.auto_map["AutoModel"],
                model_config.model)
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[i] = "TransformersModel"
        if model_config.model_impl == ModelImpl.AUTO:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[i] = "TransformersModel"
    return architectures


def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    architectures = getattr(model_config.hf_config, "architectures", [])
@@ -38,6 +87,14 @@ def get_model_architecture(
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    is_vllm_supported = any(arch in vllm_supported_archs
                            for arch in architectures)
    if (not is_vllm_supported
            or model_config.model_impl == ModelImpl.TRANSFORMERS):
        architectures = resolve_transformers_fallback(model_config,
                                                      architectures)

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
@@ -184,6 +184,10 @@ _SPECULATIVE_DECODING_MODELS = {
    "MedusaModel": ("medusa", "Medusa"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_FALLBACK_MODEL = {
    "TransformersModel": ("transformers", "TransformersModel"),
}
# yapf: enable

_VLLM_MODELS = {
@@ -192,6 +196,7 @@ _VLLM_MODELS = {
    **_CROSS_ENCODER_MODELS,
    **_MULTIMODAL_MODELS,
    **_SPECULATIVE_DECODING_MODELS,
    **_FALLBACK_MODEL,
}


@@ -378,7 +383,12 @@ class _ModelRegistry:
        if not architectures:
            logger.warning("No model architectures are specified")

        return architectures
        normalized_arch = []
        for model in architectures:
            if model not in self.models:
                model = "TransformersModel"
            normalized_arch.append(model)
        return normalized_arch

    def inspect_model_cls(
        self,
vllm/model_executor/models/transformers.py (new file, 264 lines)
@@ -0,0 +1,264 @@
# Copyright 2024 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

from vllm.attention import Attention, AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .utils import maybe_prefix

logger = init_logger(__name__)


def vllm_flash_attention_forward(
        # Transformers args
        module: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: torch.Tensor,
        # Transformers kwargs
        scaling: float = None,
        # vLLM kwargs
        attn_metadata: AttentionMetadata = None,
        attention_instances: list[Attention] = None,
        **kwargs):
    self_attn = attention_instances[module.layer_idx]
    if scaling is not None:
        self_attn.impl.scale = float(scaling)
    hidden = query.shape[-2]
    query, key, value = (x.transpose(1, 2) for x in (query, key, value))
    query, key, value = (x.reshape(hidden, -1) for x in (query, key, value))
    return self_attn.forward(
        query,
        key,
        value,
        kv_cache=None,  # argument not used
        attn_metadata=attn_metadata), None


ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward


# Linear Layer that is compatible with transformers internal forward
# TODO: This is a temporary solution, we should find a better way to integrate
class HFColumnParallelLinear(ColumnParallelLinear):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input)[0]


class HFRowParallelLinear(RowParallelLinear):

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return super().forward(input)[0]


def replace_tp_linear_class(orig_module: nn.Linear,
                            style: str,
                            quant_config=None):
    """
    In model configurations, we use a neutral type (string) to specify parallel
    styles, here we use it to translate nn.Linear into vllm-style tp Linear.

    Quant config is not supported yet
    """

    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    input_size = orig_module.in_features
    output_size = orig_module.out_features
    bias = orig_module.bias is not None

    if style == "colwise":
        return HFColumnParallelLinear(
            input_size,
            output_size,
            bias,
        )
    elif style == "rowwise":
        return HFRowParallelLinear(
            input_size,
            output_size,
            bias,
        )
    # We don't consider colwise_rep since it's used in lm_head
    else:
        raise ValueError(f"Unsupported parallel style value: {style}")


class TransformersModel(nn.Module):
    embedding_padding_modules = ["lm_head"]
    embedding_modules = ["embed_tokens"
                         ]  # TODO transformers will have a util to get it

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        logger.info("Using Transformers backend.")

        self.vllm_config = vllm_config
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.quant_config = quant_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.unpadded_vocab_size = config.vocab_size

        self.model: PreTrainedModel = AutoModel.from_config(
            self.config,
            attn_implementation="vllm",
            torch_dtype=vllm_config.model_config.dtype,
            trust_remote_code=vllm_config.model_config.trust_remote_code,
        )
        prefix = self.model.base_model_prefix

        # MLP modifications
        self.tensor_parallelize(self.model)

        # Attention modifications (assumes 1 attention op per hidden layer)
        tp_size = get_tensor_model_parallel_world_size()
        self.attention_instances = [
            Attention(
                num_heads=divide(config.num_attention_heads, tp_size),
                head_size=config.head_dim,
                # NOTE: We use Llama scale as default, if it's set by
                # Transformers, it's updated in vllm_flash_attention_forward
                scale=config.head_dim**-0.5,
                num_kv_heads=divide(config.num_key_value_heads, tp_size),
                cache_config=cache_config,
                quant_config=None,
                prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
        ]

        # Model modifications
        self.replace_vocab_embed_class(self.model)

        # ForCausalLM modifications
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=None,
                                      prefix=maybe_prefix(prefix, "lm_head"))
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.get_input_embeddings().weight

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size, logit_scale)
        self.sampler = get_sampler()

    def log_replacement(self, name: str, old_module: nn.Module,
                        new_module: nn.Module):
        logger.debug("%s: %s -> %s", name, old_module, new_module)

    def tensor_parallelize(self, module: nn.Module, prefix: str = ""):
        if (self.config.base_model_tp_plan is None
                and self.vllm_config.parallel_config.tensor_parallel_size > 1):
            raise ValueError(
                "Trying to run tensor parallelization but the model does not "
                "support it yet!")

        for child_name, child_module in module.named_children():
            qual_name = prefix + child_name
            for pattern, style in self.config.base_model_tp_plan.items():
                if re.match(pattern, qual_name) and isinstance(
                        child_module, nn.Linear):
                    new_module = replace_tp_linear_class(
                        child_module, style, self.quant_config)
                    setattr(module, child_name, new_module)
                    self.log_replacement(qual_name, child_module, new_module)
            else:
                self.tensor_parallelize(child_module, prefix=f"{qual_name}.")

    def replace_vocab_embed_class(self, module: nn.Module):
        # Use native set input embeddings
        new_module = VocabParallelEmbedding(
            self.vocab_size,
            self.config.hidden_size,
            org_num_embeddings=self.config.vocab_size,
            quant_config=None,
        )
        self.log_replacement("input embedding",
                             self.model.get_input_embeddings(), new_module)
        self.model.set_input_embeddings(new_module)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],  # argument not used
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        model_output = self.model(
            input_ids[None, ...],
            use_cache=False,
            position_ids=positions[None, ...],
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors,
            attention_instances=self.attention_instances,
            return_dict=False)[0][0, ...]  # we remove batch dimension for now
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(self, logits: torch.Tensor,
               sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]:

        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if name not in params_dict:
                name = f"{self.model.base_model_prefix}.{name}"
            if name in params_dict:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params