[Model] PP support for embedding models and update docs (#9090)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
parent f22619fe96
commit b22b798471
@@ -7,10 +7,12 @@ vLLM supports a variety of generative Transformer models in `HuggingFace Transfo
The following is the list of model architectures that are currently supported by vLLM.
Alongside each architecture, we include some popular models that use it.

----
Text-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^

Text Generation
---------------

Decoder-only Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1
@@ -40,6 +42,11 @@ Decoder-only Language Models
    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
    -
    - ✅︎
  * - :code:`BartForConditionalGeneration`
    - BART
    - :code:`facebook/bart-base`, :code:`facebook/bart-large-cnn`, etc.
    -
    -
  * - :code:`ChatGLMModel`
    - ChatGLM
    - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
@@ -259,11 +266,55 @@ Decoder-only Language Models
.. note::
  Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.

.. _supported_vlms:
Text Embedding
--------------

.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1

  * - Architecture
    - Models
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
    - :ref:`PP <distributed_serving>`
  * - :code:`Gemma2Model`
    - Gemma2-based
    - :code:`BAAI/bge-multilingual-gemma2`, etc.
    -
    - ✅︎
  * - :code:`MistralModel`
    - Mistral-based
    - :code:`intfloat/e5-mistral-7b-instruct`, etc.
    -
    - ✅︎

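As a quick illustration, embedding models can be run offline through :code:`LLM.encode`; the sketch below assumes two GPUs and the :code:`intfloat/e5-mistral-7b-instruct` checkpoint listed above, and exact keyword names may differ between vLLM versions.

.. code-block:: python

    # Hedged sketch, not an official snippet from this commit.
    from vllm import LLM

    llm = LLM(
        model="intfloat/e5-mistral-7b-instruct",
        pipeline_parallel_size=2,  # embedding models can now run with PP
    )
    outputs = llm.encode(["Hello, world!"])
    print(len(outputs[0].outputs.embedding))  # embedding dimensionality
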
Reward Modeling
---------------

.. list-table::
  :widths: 25 25 50 5 5
  :header-rows: 1

  * - Architecture
    - Models
    - Example HuggingFace Models
    - :ref:`LoRA <lora>`
    - :ref:`PP <distributed_serving>`
  * - :code:`Qwen2ForRewardModel`
    - Qwen2-based
    - :code:`Qwen/Qwen2.5-Math-RM-72B`, etc.
    -
    - ✅︎

.. note::
  As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.

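Because these reward models are currently exposed through the Embeddings API, they can be queried with the OpenAI client once a server is running; the snippet below is an illustrative sketch that assumes :code:`vllm serve Qwen/Qwen2.5-Math-RM-72B` is already listening locally.

.. code-block:: python

    # Illustrative sketch; endpoint and model name depend on how the server was started.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    response = client.embeddings.create(
        model="Qwen/Qwen2.5-Math-RM-72B",
        input="The answer is 42.",
        encoding_format="float",
    )
    print(response.data[0].embedding)  # reward scores come back in the embedding field
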
Multimodal Language Models
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. _supported_vlms:

.. list-table::
  :widths: 25 25 25 25 5 5
  :header-rows: 1
@@ -378,6 +429,7 @@ Multimodal Language Models
  For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
  For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630

----

If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` and :ref:`Enabling Multimodal Inputs <enabling_multimodal_inputs>`
@@ -6,10 +6,9 @@ Using VLMs
vLLM provides experimental support for Vision Language Models (VLMs). See the :ref:`list of supported VLMs here <supported_vlms>`.
This document shows you how to run and serve these models using vLLM.

.. important::
  We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation.

  We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.
.. note::
  We are actively iterating on VLM support. See `this RFC <https://github.com/vllm-project/vllm/issues/4194>`_ for upcoming changes,
  and `open an issue on GitHub <https://github.com/vllm-project/vllm/issues/new/choose>`_ if you have any feedback or feature requests.

Offline Inference
-----------------
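To ground the section that follows, a minimal offline VLM call looks roughly like the sketch below; the model name and prompt template are illustrative, and each VLM expects its own template.

.. code-block:: python

    # Rough sketch only; consult the examples directory for per-model prompt formats.
    from PIL import Image

    from vllm import LLM

    llm = LLM(model="llava-hf/llava-1.5-7b-hf")
    image = Image.open("example.jpg")
    outputs = llm.generate({
        "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
        "multi_modal_data": {"image": image},
    })
    print(outputs[0].outputs[0].text)
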
@@ -7,7 +7,7 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
"""
import os
from dataclasses import dataclass
from typing import List, NamedTuple, Optional
from typing import List, Literal, NamedTuple, Optional

import pytest

@ -97,6 +97,9 @@ class PPTestSettings:
|
||||
self.trust_remote_code, self.tokenizer_mode)
|
||||
|
||||
|
||||
# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
|
||||
# The values displayed here are only a rough indicator of the size of the model
|
||||
|
||||
# yapf: disable
|
||||
GENERATION_MODEL_SETTINGS = {
|
||||
# [DETAILED TESTS]
|
||||
@ -104,15 +107,13 @@ GENERATION_MODEL_SETTINGS = {
|
||||
# [FAST TESTS]
|
||||
# Uses Llama
|
||||
# "BAAI/AquilaChat-7B": PPTestSettings.fast(),
|
||||
# TODO: Test on larger GPU
|
||||
# "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
|
||||
"Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(tp_base=8, trust_remote_code=True), # noqa: E501
|
||||
"baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True),
|
||||
"baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
|
||||
"bigscience/bloomz-1b1": PPTestSettings.fast(),
|
||||
"THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True),
|
||||
"CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501
|
||||
# TODO: Test on larger GPU
|
||||
# "databricks/dbrx-instruct": PPTestSettings.fast(),
|
||||
"databricks/dbrx-instruct": PPTestSettings.fast(tp_base=8),
|
||||
"Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True),
|
||||
"deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
|
||||
"deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501
|
||||
@ -161,8 +162,9 @@ GENERATION_MODEL_SETTINGS = {
|
||||
|
||||
EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated]
|
||||
# [FAST TESTS]
|
||||
# Uses Llama
|
||||
# "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
|
||||
"intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
|
||||
"BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
|
||||
"Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(tp_base=4, trust_remote_code=True), # noqa: E501
|
||||
}
|
||||
|
||||
MULTIMODAL_MODEL_SETTINGS = {
|
||||
@ -192,40 +194,35 @@ CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated]
|
||||
}
|
||||
# yapf: enable
|
||||
|
||||
MODEL_SETTINGS = {
|
||||
**GENERATION_MODEL_SETTINGS,
|
||||
**EMBEDDING_MODEL_SETTINGS,
|
||||
**MULTIMODAL_MODEL_SETTINGS,
|
||||
}
|
||||
|
||||
# You can update this on your local machine to run specific tests
|
||||
# NOTE: You can update this on your local machine to run specific tests
|
||||
TEST_MODELS = [
|
||||
# [LANGUAGE GENERATION]
|
||||
"meta-llama/Meta-Llama-3-8B",
|
||||
"facebook/chameleon-7b",
|
||||
"ibm/PowerLM-3b",
|
||||
# [LANGUAGE EMBEDDING]
|
||||
"intfloat/e5-mistral-7b-instruct",
|
||||
"BAAI/bge-multilingual-gemma2",
|
||||
# [MULTIMODAL GENERATION]
|
||||
"OpenGVLab/InternVL2-1B",
|
||||
"microsoft/Phi-3-vision-128k-instruct",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"fixie-ai/ultravox-v0_3",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "parallel_setup", "distributed_backend",
|
||||
"trust_remote_code", "tokenizer_mode"),
|
||||
[
|
||||
params for model_name, settings in MODEL_SETTINGS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
if model_name in TEST_MODELS
|
||||
],
|
||||
)
|
||||
@fork_new_process_for_each_test
|
||||
def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
|
||||
distributed_backend: str, trust_remote_code: bool,
|
||||
tokenizer_mode: Optional[str], num_gpus_available):
|
||||
def _compare_tp(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
trust_remote_code: bool,
|
||||
tokenizer_mode: Optional[str],
|
||||
num_gpus_available: int,
|
||||
*,
|
||||
method: Literal["generate", "encode"] = "encode",
|
||||
):
|
||||
tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup
|
||||
|
||||
if num_gpus_available < tp_size:
|
||||
pytest.skip(f"Need at least {tp_size} GPUs to run the test")
|
||||
if num_gpus_available < tp_size * pp_size:
|
||||
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")
|
||||
if VLLM_MULTI_NODE and distributed_backend == "mp":
|
||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend")
|
||||
@ -286,10 +283,95 @@ def test_compare_tp(model_name: str, parallel_setup: ParallelSetup,
|
||||
]
|
||||
|
||||
try:
|
||||
compare_two_settings(model_name, pp_args, tp_args, pp_env)
|
||||
compare_two_settings(model_name,
|
||||
pp_args,
|
||||
tp_args,
|
||||
pp_env,
|
||||
method=method)
|
||||
except Exception:
|
||||
if pp_env is None:
|
||||
raise
|
||||
else:
|
||||
# Ray ADAG tests are flaky, so we don't want to fail the test
|
||||
logger.exception("Ray ADAG tests failed")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "parallel_setup", "distributed_backend",
|
||||
"trust_remote_code", "tokenizer_mode"),
|
||||
[
|
||||
params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
if model_name in TEST_MODELS
|
||||
],
|
||||
)
|
||||
@fork_new_process_for_each_test
|
||||
def test_tp_language_generation(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
trust_remote_code: bool,
|
||||
tokenizer_mode: Optional[str],
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_tp(model_name,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
trust_remote_code,
|
||||
tokenizer_mode,
|
||||
num_gpus_available,
|
||||
method="generate")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "parallel_setup", "distributed_backend",
|
||||
"trust_remote_code", "tokenizer_mode"),
|
||||
[
|
||||
params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
if model_name in TEST_MODELS
|
||||
],
|
||||
)
|
||||
@fork_new_process_for_each_test
|
||||
def test_tp_language_embedding(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
trust_remote_code: bool,
|
||||
tokenizer_mode: Optional[str],
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_tp(model_name,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
trust_remote_code,
|
||||
tokenizer_mode,
|
||||
num_gpus_available,
|
||||
method="encode")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_name", "parallel_setup", "distributed_backend",
|
||||
"trust_remote_code", "tokenizer_mode"),
|
||||
[
|
||||
params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
|
||||
for params in settings.iter_params(model_name)
|
||||
if model_name in TEST_MODELS
|
||||
],
|
||||
)
|
||||
@fork_new_process_for_each_test
|
||||
def test_tp_multimodal_generation(
|
||||
model_name: str,
|
||||
parallel_setup: ParallelSetup,
|
||||
distributed_backend: str,
|
||||
trust_remote_code: bool,
|
||||
tokenizer_mode: Optional[str],
|
||||
num_gpus_available,
|
||||
):
|
||||
_compare_tp(model_name,
|
||||
parallel_setup,
|
||||
distributed_backend,
|
||||
trust_remote_code,
|
||||
tokenizer_mode,
|
||||
num_gpus_available,
|
||||
method="generate")
|
||||
|
229
tests/utils.py
@@ -8,13 +8,13 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Union

import openai
import pytest
import requests
from openai.types.completion import Completion
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, assert_never

from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized,
@ -163,11 +163,140 @@ class RemoteOpenAIServer:
|
||||
)
|
||||
|
||||
|
||||
def _test_completion(
|
||||
client: openai.OpenAI,
|
||||
model: str,
|
||||
prompt: str,
|
||||
token_ids: List[int],
|
||||
):
|
||||
results = []
|
||||
|
||||
# test with text prompt
|
||||
completion = client.completions.create(model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
results.append({
|
||||
"test": "single_completion",
|
||||
"text": completion.choices[0].text,
|
||||
"finish_reason": completion.choices[0].finish_reason,
|
||||
"usage": completion.usage,
|
||||
})
|
||||
|
||||
# test using token IDs
|
||||
completion = client.completions.create(
|
||||
model=model,
|
||||
prompt=token_ids,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
results.append({
|
||||
"test": "token_ids",
|
||||
"text": completion.choices[0].text,
|
||||
"finish_reason": completion.choices[0].finish_reason,
|
||||
"usage": completion.usage,
|
||||
})
|
||||
|
||||
# test seeded random sampling
|
||||
completion = client.completions.create(model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
seed=33,
|
||||
temperature=1.0)
|
||||
|
||||
results.append({
|
||||
"test": "seeded_sampling",
|
||||
"text": completion.choices[0].text,
|
||||
"finish_reason": completion.choices[0].finish_reason,
|
||||
"usage": completion.usage,
|
||||
})
|
||||
|
||||
# test seeded random sampling with multiple prompts
|
||||
completion = client.completions.create(model=model,
|
||||
prompt=[prompt, prompt],
|
||||
max_tokens=5,
|
||||
seed=33,
|
||||
temperature=1.0)
|
||||
|
||||
results.append({
|
||||
"test":
|
||||
"seeded_sampling",
|
||||
"text": [choice.text for choice in completion.choices],
|
||||
"finish_reason":
|
||||
[choice.finish_reason for choice in completion.choices],
|
||||
"usage":
|
||||
completion.usage,
|
||||
})
|
||||
|
||||
# test simple list
|
||||
batch = client.completions.create(
|
||||
model=model,
|
||||
prompt=[prompt, prompt],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
results.append({
|
||||
"test": "simple_list",
|
||||
"text0": batch.choices[0].text,
|
||||
"text1": batch.choices[1].text,
|
||||
})
|
||||
|
||||
# test streaming
|
||||
batch = client.completions.create(
|
||||
model=model,
|
||||
prompt=[prompt, prompt],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
texts = [""] * 2
|
||||
for chunk in batch:
|
||||
assert len(chunk.choices) == 1
|
||||
choice = chunk.choices[0]
|
||||
texts[choice.index] += choice.text
|
||||
|
||||
results.append({
|
||||
"test": "streaming",
|
||||
"texts": texts,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _test_embeddings(
|
||||
client: openai.OpenAI,
|
||||
model: str,
|
||||
text: str,
|
||||
):
|
||||
results = []
|
||||
|
||||
# test with text input
|
||||
embeddings = client.embeddings.create(
|
||||
model=model,
|
||||
input=text,
|
||||
encoding_format="float",
|
||||
)
|
||||
|
||||
results.append({
|
||||
"test": "single_embedding",
|
||||
"embedding": embeddings.data[0].embedding,
|
||||
"usage": embeddings.usage,
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def compare_two_settings(model: str,
|
||||
arg1: List[str],
|
||||
arg2: List[str],
|
||||
env1: Optional[Dict[str, str]] = None,
|
||||
env2: Optional[Dict[str, str]] = None,
|
||||
*,
|
||||
method: Literal["generate", "encode"] = "generate",
|
||||
max_wait_seconds: Optional[float] = None) -> None:
|
||||
"""
|
||||
Launch API server with two different sets of arguments/environments
|
||||
@ -219,96 +348,12 @@ def compare_two_settings(model: str,
|
||||
"root": served_model.root,
|
||||
})
|
||||
|
||||
# test with text prompt
|
||||
completion = client.completions.create(model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
temperature=0.0)
|
||||
|
||||
results.append({
|
||||
"test": "single_completion",
|
||||
"text": completion.choices[0].text,
|
||||
"finish_reason": completion.choices[0].finish_reason,
|
||||
"usage": completion.usage,
|
||||
})
|
||||
|
||||
# test using token IDs
|
||||
completion = client.completions.create(
|
||||
model=model,
|
||||
prompt=token_ids,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
results.append({
|
||||
"test": "token_ids",
|
||||
"text": completion.choices[0].text,
|
||||
"finish_reason": completion.choices[0].finish_reason,
|
||||
"usage": completion.usage,
|
||||
})
|
||||
|
||||
# test seeded random sampling
|
||||
completion = client.completions.create(model=model,
|
||||
prompt=prompt,
|
||||
max_tokens=5,
|
||||
seed=33,
|
||||
temperature=1.0)
|
||||
|
||||
results.append({
|
||||
"test": "seeded_sampling",
|
||||
"text": completion.choices[0].text,
|
||||
"finish_reason": completion.choices[0].finish_reason,
|
||||
"usage": completion.usage,
|
||||
})
|
||||
|
||||
# test seeded random sampling with multiple prompts
|
||||
completion = client.completions.create(model=model,
|
||||
prompt=[prompt, prompt],
|
||||
max_tokens=5,
|
||||
seed=33,
|
||||
temperature=1.0)
|
||||
|
||||
results.append({
|
||||
"test":
|
||||
"seeded_sampling",
|
||||
"text": [choice.text for choice in completion.choices],
|
||||
"finish_reason":
|
||||
[choice.finish_reason for choice in completion.choices],
|
||||
"usage":
|
||||
completion.usage,
|
||||
})
|
||||
|
||||
# test simple list
|
||||
batch = client.completions.create(
|
||||
model=model,
|
||||
prompt=[prompt, prompt],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
|
||||
results.append({
|
||||
"test": "simple_list",
|
||||
"text0": batch.choices[0].text,
|
||||
"text1": batch.choices[1].text,
|
||||
})
|
||||
|
||||
# test streaming
|
||||
batch = client.completions.create(
|
||||
model=model,
|
||||
prompt=[prompt, prompt],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
stream=True,
|
||||
)
|
||||
texts = [""] * 2
|
||||
for chunk in batch:
|
||||
assert len(chunk.choices) == 1
|
||||
choice = chunk.choices[0]
|
||||
texts[choice.index] += choice.text
|
||||
results.append({
|
||||
"test": "streaming",
|
||||
"texts": texts,
|
||||
})
|
||||
if method == "generate":
|
||||
results += _test_completion(client, model, prompt, token_ids)
|
||||
elif method == "encode":
|
||||
results += _test_embeddings(client, model, prompt)
|
||||
else:
|
||||
assert_never(method)
|
||||
|
||||
n = len(results) // 2
|
||||
arg1_results = results[:n]
|
||||
|
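With the new "method" switch above, the same comparison helper drives both generation and embedding checks under pipeline parallelism. A hedged usage sketch follows; the flag values are illustrative, and the real argument lists are built inside _compare_tp.

# Hypothetical invocation for an embedding model; not taken verbatim from this diff.
from tests.utils import compare_two_settings

pp_args = ["--pipeline-parallel-size", "2", "--enforce-eager"]
tp_args = ["--tensor-parallel-size", "2", "--enforce-eager"]

compare_two_settings(
    "intfloat/e5-mistral-7b-instruct",
    pp_args,
    tp_args,
    method="encode",  # exercise the Embeddings API instead of Completions
)
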
@ -40,7 +40,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
from .utils import (group_weights_with_prefix, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -273,7 +273,7 @@ class Gemma2Model(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
@ -308,6 +308,49 @@ class Gemma2Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
continue
|
||||
name = name.replace(shard_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
unloaded_params = params_dict.keys() - loaded_params
|
||||
if unloaded_params:
|
||||
logger.warning(
|
||||
"Some weights are not initialized from checkpoints: %s",
|
||||
unloaded_params)
|
||||
|
||||
|
||||
class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
@ -391,48 +434,19 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, shard_name, shard_id) in stacked_params_mapping:
|
||||
if shard_name not in name:
|
||||
weights_group = group_weights_with_prefix(weights)
|
||||
|
||||
self.model.load_weights(weights_group["model"])
|
||||
|
||||
if not self.config.tie_word_embeddings:
|
||||
# NOTE: For now self.lm_head is not defined because
|
||||
# tie_word_embeddings is assumed to the False
|
||||
lm_head_dict = dict(self.lm_head.named_parameters())
|
||||
for name, loaded_weight in weights_group["lm_head"]:
|
||||
if is_pp_missing_parameter(name, self.lm_head):
|
||||
continue
|
||||
name = name.replace(shard_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# lm_head is not used in vllm as it is tied with embed_token.
|
||||
# To prevent errors, skip loading lm_head.weight.
|
||||
if "lm_head.weight" in name:
|
||||
continue
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
param = lm_head_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
unloaded_params = params_dict.keys() - loaded_params
|
||||
if unloaded_params:
|
||||
logger.warning(
|
||||
"Some weights are not initialized from checkpoints: %s",
|
||||
unloaded_params)
|
||||
|
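The load_weights refactor above delegates to group_weights_with_prefix, whose implementation is not shown in this diff. As a hedged sketch of the idea — bucketing checkpoint entries by their first dot-separated component so that sub-modules such as "model" and "lm_head" can load their own slices — it might look like the following:

# Hypothetical sketch; the real helper may return iterators rather than lists.
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

import torch


def group_weights_with_prefix_sketch(
    weights: Iterable[Tuple[str, torch.Tensor]],
) -> Dict[str, List[Tuple[str, torch.Tensor]]]:
    grouped: Dict[str, List[Tuple[str, torch.Tensor]]] = defaultdict(list)
    for name, tensor in weights:
        prefix, _, rest = name.partition(".")
        # "model.layers.0.self_attn..." is handed to the sub-loader as
        # "layers.0.self_attn...".
        grouped[prefix].append((rest, tensor))
    return grouped
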
@ -1,17 +1,18 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.gemma2 import Gemma2Model
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .gemma2 import Gemma2Model
|
||||
from .interfaces import SupportsPP
|
||||
|
||||
class Gemma2EmbeddingModel(nn.Module):
|
||||
|
||||
class Gemma2EmbeddingModel(nn.Module, SupportsPP):
|
||||
"""A model that uses Gemma2 with additional embedding functionalities.
|
||||
|
||||
This class encapsulates the Gemma2Model and provides an interface for
|
||||
@ -30,6 +31,9 @@ class Gemma2EmbeddingModel(nn.Module):
|
||||
self.model = Gemma2Model(**kwargs)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.Tensor],
|
||||
@ -38,10 +42,9 @@ class Gemma2EmbeddingModel(nn.Module):
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.model.forward(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
@ -51,32 +54,4 @@ class Gemma2EmbeddingModel(nn.Module):
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.model.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
self.model.load_weights(weights)
|
||||
|
@ -51,7 +51,8 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_hip
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
from .utils import (PPMissingLayer, group_weights_with_prefix,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
@ -347,6 +348,90 @@ class LlamaModel(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||
quantization_param_path, tp_rank, tp_size,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.__class__.model_type):
|
||||
if not isinstance(self.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
# scaling_factor = tensor_amax / FPtype_max
|
||||
scaling_factor *= 2
|
||||
if hasattr(layer_self_attn, "kv_scale"):
|
||||
layer_self_attn.attn._kv_scale = scaling_factor
|
||||
else:
|
||||
raise RuntimeError("Self attention has no KV cache scaling "
|
||||
"factor attribute!")
|
||||
|
||||
|
||||
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
@ -372,6 +457,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
"gate_proj": ("gate_up_proj", 0),
|
||||
"up_proj": ("gate_up_proj", 1),
|
||||
}
|
||||
|
||||
# Mistral/Llama models can also be loaded with --load-format mistral
|
||||
# from consolidated.safetensors checkpoints
|
||||
mistral_mapping = {
|
||||
@ -465,103 +551,38 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
weights = [
|
||||
self.maybe_remap_mistral(name, loaded_weight)
|
||||
for name, loaded_weight in weights
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight)
|
||||
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
# With tie_word_embeddings, we can skip lm_head.weight
|
||||
# The weight might appear unnecessarily in the files if the model is
|
||||
# processed with quantization, LoRA, fine-tuning, etc.
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
if scale_name := get_compressed_tensors_cache_scale(name):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
weights_group = group_weights_with_prefix(weights)
|
||||
|
||||
self.model.load_weights(weights_group["model"])
|
||||
|
||||
if not self.config.tie_word_embeddings:
|
||||
lm_head_dict = dict(self.lm_head.named_parameters())
|
||||
for name, loaded_weight in weights_group["lm_head"]:
|
||||
if is_pp_missing_parameter(name, self.lm_head):
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
param = lm_head_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# If this function is called, it should always initialize KV cache scale
|
||||
# factors (or else raise an exception). Thus, handled exceptions should
|
||||
# make sure to leave KV cache scale factors in a known good (dummy) state
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
||||
quantization_param_path, tp_rank, tp_size,
|
||||
self.config.num_hidden_layers,
|
||||
self.config.__class__.model_type):
|
||||
if not isinstance(self.model.layers[layer_idx], nn.Identity):
|
||||
layer_self_attn = self.model.layers[layer_idx].self_attn
|
||||
|
||||
if is_hip():
|
||||
# The scaling factor convention we are assuming is
|
||||
# quantized_value * scaling_factor ~= true_value
|
||||
# which is consistent with the practice of setting
|
||||
# scaling_factor = tensor_amax / FPtype_max
|
||||
scaling_factor *= 2
|
||||
if hasattr(layer_self_attn, "kv_scale"):
|
||||
layer_self_attn.attn._kv_scale = scaling_factor
|
||||
else:
|
||||
raise RuntimeError("Self attention has no KV cache scaling "
|
||||
"factor attribute!")
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
||||
# This function is used to remap the mistral format as
|
||||
# used by Mistral and Llama <=2
|
||||
def maybe_remap_mistral(
|
||||
self, name: str,
|
||||
loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]:
|
||||
self,
|
||||
name: str,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> Tuple[str, torch.Tensor]:
|
||||
|
||||
def permute(w, n_heads):
|
||||
def permute(w: torch.Tensor, n_heads: int):
|
||||
attn_in = self.config.head_dim * n_heads
|
||||
attn_out = self.config.hidden_size
|
||||
|
||||
|
@ -5,13 +5,11 @@ from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.llama import LlamaModel
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import is_pp_missing_parameter
|
||||
from .llama import LlamaModel
|
||||
|
||||
|
||||
class LlamaEmbeddingModel(nn.Module, SupportsPP):
|
||||
@ -44,9 +42,8 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
return self.model.forward(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return self.model(input_ids, positions, kv_caches, attn_metadata,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
@ -56,43 +53,7 @@ class LlamaEmbeddingModel(nn.Module, SupportsPP):
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.model.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
self.model.load_weights(weights)
|
||||
|
||||
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
||||
self.model.load_kv_cache_scales(quantization_param_path)
|
||||
|
@ -48,7 +48,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||
from .utils import (PPMissingLayer, group_weights_with_prefix,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers)
|
||||
|
||||
|
||||
@ -300,6 +301,47 @@ class Qwen2Model(nn.Module):
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
@ -393,44 +435,17 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
weights_group = group_weights_with_prefix(weights)
|
||||
|
||||
self.model.load_weights(weights_group["model"])
|
||||
|
||||
if not self.config.tie_word_embeddings:
|
||||
lm_head_dict = dict(self.lm_head.named_parameters())
|
||||
for name, loaded_weight in weights_group["lm_head"]:
|
||||
if is_pp_missing_parameter(name, self.lm_head):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
|
||||
param = lm_head_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
@ -4,7 +4,7 @@
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
"""Inference-only Qwen2-RM model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -15,15 +15,14 @@ from vllm.config import CacheConfig, LoRAConfig
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.qwen2 import Qwen2Model
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .utils import is_pp_missing_parameter
|
||||
from .interfaces import SupportsPP
|
||||
from .qwen2 import Qwen2Model
|
||||
from .utils import group_weights_with_prefix
|
||||
|
||||
|
||||
class ReLU(nn.Module):
|
||||
@ -37,7 +36,7 @@ class ReLU(nn.Module):
|
||||
return self.activation(input)
|
||||
|
||||
|
||||
class Qwen2ForRewardModel(nn.Module):
|
||||
class Qwen2ForRewardModel(nn.Module, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -97,6 +96,9 @@ class Qwen2ForRewardModel(nn.Module):
|
||||
)
|
||||
self._pooler = Pooler(pooling_type=PoolingType.ALL, normalize=False)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@ -104,7 +106,7 @@ class Qwen2ForRewardModel(nn.Module):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors)
|
||||
logits, _ = self.score(hidden_states)
|
||||
@ -118,45 +120,13 @@ class Qwen2ForRewardModel(nn.Module):
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters(remove_duplicate=False))
|
||||
for name, loaded_weight in weights:
|
||||
# Skip loading lm_head for embedding model
|
||||
if name == "lm_head.weight":
|
||||
continue
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
weights_group = group_weights_with_prefix(weights)
|
||||
|
||||
self.model.load_weights(weights_group["model"])
|
||||
|
||||
score_dict = dict(self.score.named_parameters())
|
||||
for name, loaded_weight in weights_group["score"]:
|
||||
param = score_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
@@ -306,10 +306,12 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:

def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model."""
    for missing_layer_name in get_pp_missing_layer_names(model):
        if name.startswith(missing_layer_name):
            return True
    return False
    if isinstance(model, PPMissingLayer):
        return True

    return any(
        name.startswith(missing_layer_name)
        for missing_layer_name in get_pp_missing_layer_names(model))


def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
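For intuition, the PP-missing check simply tests whether a parameter name falls under a layer prefix that this pipeline stage does not own; a self-contained toy version of the same any(...) pattern:

# Hypothetical, self-contained illustration (the real code queries the model for its missing layers).
missing_layer_names = ["layers.0.", "layers.1."]  # owned by another pipeline stage


def is_missing(param_name: str) -> bool:
    return any(param_name.startswith(prefix) for prefix in missing_layer_names)


assert is_missing("layers.0.self_attn.qkv_proj.weight")
assert not is_missing("layers.2.mlp.gate_up_proj.weight")
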
@ -1,11 +1,12 @@
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig)
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
@ -66,7 +67,7 @@ class EmbeddingModelRunner(
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[PoolerOutput]]:
|
||||
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"EmbeddingModelRunner does not support multi-step execution.")
|
||||
@ -107,28 +108,52 @@ class EmbeddingModelRunner(
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
|
||||
execute_model_kwargs = {
|
||||
"input_ids":
|
||||
model_input.input_tokens,
|
||||
"positions":
|
||||
model_input.input_positions,
|
||||
"kv_caches":
|
||||
kv_caches,
|
||||
"attn_metadata":
|
||||
model_input.attn_metadata,
|
||||
**MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {},
|
||||
device=self.device),
|
||||
}
|
||||
multi_modal_kwargs = model_input.multi_modal_kwargs or {}
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_start = torch.cuda.Event(enable_timing=True)
|
||||
model_forward_end = torch.cuda.Event(enable_timing=True)
|
||||
model_forward_start.record()
|
||||
|
||||
with set_forward_context(model_input.attn_metadata):
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
hidden_or_intermediate_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=model_input.attn_metadata,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**MultiModalInputs.as_kwargs(multi_modal_kwargs,
|
||||
device=self.device))
|
||||
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.record()
|
||||
|
||||
# Only perform pooling in the last pipeline stage.
|
||||
if not get_pp_group().is_last_rank:
|
||||
if (self.is_driver_worker
|
||||
and hidden_or_intermediate_states is not None
|
||||
and isinstance(hidden_or_intermediate_states,
|
||||
IntermediateTensors)
|
||||
and self.observability_config is not None
|
||||
and self.observability_config.collect_model_forward_time):
|
||||
model_forward_end.synchronize()
|
||||
model_forward_time = model_forward_start.elapsed_time(
|
||||
model_forward_end)
|
||||
orig_model_forward_time = 0.0
|
||||
if intermediate_tensors is not None:
|
||||
orig_model_forward_time = intermediate_tensors.tensors.get(
|
||||
"model_forward_time", torch.tensor(0.0)).item()
|
||||
hidden_or_intermediate_states.tensors["model_forward_time"] = (
|
||||
torch.tensor(model_forward_time + orig_model_forward_time))
|
||||
return hidden_or_intermediate_states
|
||||
|
||||
# Only perform pooling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
return [
|
||||
self.model.pooler(hidden_states=hidden_states,
|
||||
self.model.pooler(hidden_states=hidden_or_intermediate_states,
|
||||
pooling_metadata=model_input.pooling_metadata)
|
||||
]
|
||||
|
||||
|