[Model] Add user-configurable task for models that support both generation and embedding (#9424)

Cyrus Leung 2024-10-19 02:31:58 +08:00 committed by GitHub
parent 7dbe738d65
commit 051eaf6db3
33 changed files with 451 additions and 201 deletions

View File

@ -294,6 +294,10 @@ Text Embedding
-
- ✅︎
.. important::
Some model architectures support both generation and embedding tasks.
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
Reward Modeling
---------------
@ -482,6 +486,10 @@ Multimodal Embedding
- 🚧
- ✅︎
.. important::
Some model architectures support both generation and embedding tasks.
In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
----
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
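
For a model architecture with dual support, the task can be pinned at construction time. A minimal sketch (not part of this diff; the model name is taken from the VLM2Vec example later in this commit):

from vllm import LLM

# Phi3VForCausalLM-based VLM2Vec supports both generation and embedding,
# so "auto" would pick generation; select the embedding task explicitly.
llm = LLM(
    model="TIGER-Lab/VLM2Vec-Full",
    task="embedding",
    trust_remote_code=True,
    max_model_len=4096,
)
# llm.encode(...) is now available, while llm.generate(...) raises a
# ValueError that points back to the task setting.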

View File

@ -181,8 +181,8 @@ Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruc
.. code-block:: bash
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
--trust-remote-code --limit-mm-per-prompt image=2
vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
.. important::
Since OpenAI Vision API is based on `Chat Completions <https://platform.openai.com/docs/api-reference/chat>`_ API,

View File

@ -7,6 +7,7 @@ prompt = "<|image_1|> Represent the given image with the following question: Wha
# Create an LLM.
llm = LLM(
model="TIGER-Lab/VLM2Vec-Full",
task="embedding",
trust_remote_code=True,
max_model_len=4096,
max_num_seqs=2,

View File

@ -7,8 +7,8 @@ Launch the vLLM server with the following command:
vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
(multi-image inference with Phi-3.5-vision-instruct)
vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
--trust-remote-code --limit-mm-per-prompt image=2
vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
--trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
(audio inference with Ultravox)
vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096

View File

@ -25,7 +25,7 @@ from tests.models.utils import (TokensTextLogprobs,
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.config import TokenizerPoolConfig
from vllm.config import TaskOption, TokenizerPoolConfig
from vllm.connections import global_http_connection
from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel,
@ -619,6 +619,7 @@ class VllmRunner:
def __init__(
self,
model_name: str,
task: TaskOption = "auto",
tokenizer_name: Optional[str] = None,
# Use smaller max model length, otherwise bigger model cannot run due
# to kv cache size limit.
@ -634,6 +635,7 @@ class VllmRunner:
) -> None:
self.model = LLM(
model=model_name,
task=task,
tokenizer=tokenizer_name,
trust_remote_code=True,
dtype=dtype,

View File

@ -33,7 +33,8 @@ def test_simple():
num_seq_group = 4
max_model_len = 16
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens,
scheduler_config = SchedulerConfig("generate",
max_num_batched_tokens,
num_seq_group,
max_model_len,
enable_chunked_prefill=True)
@ -78,6 +79,7 @@ def test_chunk():
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -126,6 +128,7 @@ def test_complex():
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -196,6 +199,7 @@ def test_maximal_decoding():
max_model_len = 8
max_num_batched_tokens = 2
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -289,6 +293,7 @@ def test_prompt_limit():
max_model_len = 64
max_num_batched_tokens = 32
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -321,7 +326,8 @@ def test_prompt_limit_exceed():
max_seqs = 64
max_model_len = 32
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(max_num_batched_tokens,
scheduler_config = SchedulerConfig("generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
enable_chunked_prefill=True)
@ -348,6 +354,7 @@ def test_swap():
max_model_len = 200
max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -404,6 +411,7 @@ def test_running_prefill_prioritized_over_swap():
max_model_len = 200
max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -498,6 +506,7 @@ def test_chunked_prefill_preempt():
max_model_len = 200
max_num_batched_tokens = 30
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -563,6 +572,7 @@ def test_chunked_prefill_max_seqs():
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,
@ -617,6 +627,7 @@ def test_perfix_caching():
max_model_len = 80
max_num_batched_tokens = 64
scheduler_config = SchedulerConfig(
"generate",
max_num_batched_tokens,
max_seqs,
max_model_len,

View File

@ -20,9 +20,10 @@ from .utils import (append_new_token, append_new_token_seq_group,
def test_scheduler_add_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(
100,
64,
1,
"generate",
max_num_batched_tokens=100,
max_num_seqs=64,
max_model_len=1,
)
cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
cache_config.num_cpu_blocks = 4
@ -42,9 +43,10 @@ def test_scheduler_add_seq_group():
def test_scheduler_abort_seq_group():
block_size = 4
scheduler_config = SchedulerConfig(
100,
64,
1,
"generate",
max_num_batched_tokens=100,
max_num_seqs=64,
max_model_len=1,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 4
@ -70,9 +72,10 @@ def test_scheduler_schedule_simple():
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(
64,
num_seq_group,
max_model_len,
"generate",
max_num_batched_tokens=64,
max_num_seqs=num_seq_group,
max_model_len=max_model_len,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
@ -114,9 +117,10 @@ def test_scheduler_prefill_prioritized():
max_model_len = 30
max_batched_num_tokens = 30
scheduler_config = SchedulerConfig(
max_batched_num_tokens,
2,
max_model_len,
"generate",
max_num_batched_tokens=max_batched_num_tokens,
max_num_seqs=2,
max_model_len=max_model_len,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16
@ -145,9 +149,10 @@ def test_scheduler_schedule_preempt_abort():
block_size = 4
max_model_len = 16
scheduler_config = SchedulerConfig(
64,
2,
max_model_len,
"generate",
max_num_batched_tokens=64,
max_num_seqs=2,
max_model_len=max_model_len,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 2
@ -204,9 +209,10 @@ def test_scheduler_max_seqs():
max_seq_group = 2
max_model_len = 16
scheduler_config = SchedulerConfig(
64,
max_seq_group,
max_model_len,
"generate",
max_num_batched_tokens=64,
max_num_seqs=max_seq_group,
max_model_len=max_model_len,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 8
@ -248,9 +254,10 @@ def test_scheduler_max_seqs():
def test_scheduler_delay_factor():
block_size = 4
scheduler_config = SchedulerConfig(
100,
64,
16,
"generate",
max_num_batched_tokens=100,
max_num_seqs=64,
max_model_len=16,
delay_factor=0.5,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
@ -350,9 +357,10 @@ def initialize_scheduler(
):
block_size = block_size
scheduler_config = SchedulerConfig(
max_token_budget,
max_num_seqs,
max_model_len,
"generate",
max_num_batched_tokens=max_token_budget,
max_num_seqs=max_num_seqs,
max_model_len=max_model_len,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = num_cpu_blocks

View File

@ -36,7 +36,12 @@ def test_scheduler_schedule_simple_encoder_decoder():
block_size = 4
num_seq_group = 4
max_model_len = 16
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
scheduler_config = SchedulerConfig(
task="generate",
max_num_batched_tokens=64,
max_num_seqs=num_seq_group,
max_model_len=max_model_len,
)
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group
cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group

View File

@ -11,6 +11,7 @@ from typing import List, Literal, NamedTuple, Optional
import pytest
from vllm.config import TaskOption
from vllm.logger import init_logger
from ..utils import compare_two_settings, fork_new_process_for_each_test
@ -31,6 +32,7 @@ class ParallelSetup(NamedTuple):
class PPTestSettings:
parallel_setups: List[ParallelSetup]
distributed_backends: List[str]
task: TaskOption
trust_remote_code: bool
tokenizer_mode: Optional[str]
@ -39,6 +41,7 @@ class PPTestSettings:
*,
tp_base: int = 1,
pp_base: int = 2,
task: TaskOption = "auto",
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
):
@ -66,6 +69,7 @@ class PPTestSettings:
chunked_prefill=False),
],
distributed_backends=["mp", "ray"],
task=task,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
)
@ -75,6 +79,7 @@ class PPTestSettings:
*,
tp_base: int = 1,
pp_base: int = 2,
task: TaskOption = "auto",
trust_remote_code: bool = False,
tokenizer_mode: Optional[str] = None,
):
@ -86,6 +91,7 @@ class PPTestSettings:
chunked_prefill=False),
],
distributed_backends=["mp"],
task=task,
trust_remote_code=trust_remote_code,
tokenizer_mode=tokenizer_mode,
)
@ -94,7 +100,7 @@ class PPTestSettings:
for parallel_setup in self.parallel_setups:
for distributed_backend in self.distributed_backends:
yield (model_name, parallel_setup, distributed_backend,
self.trust_remote_code, self.tokenizer_mode)
self.task, self.trust_remote_code, self.tokenizer_mode)
# NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@ -213,6 +219,7 @@ def _compare_tp(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available: int,
@ -240,6 +247,8 @@ def _compare_tp(
common_args.append("--enable-chunked-prefill")
if eager_mode:
common_args.append("--enforce-eager")
if task != "auto":
common_args.extend(["--task", task])
if trust_remote_code:
common_args.append("--trust-remote-code")
if tokenizer_mode:
@ -297,7 +306,7 @@ def _compare_tp(
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
("model_name", "parallel_setup", "distributed_backend", "task",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
@ -310,6 +319,7 @@ def test_tp_language_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
@ -317,6 +327,7 @@ def test_tp_language_generation(
_compare_tp(model_name,
parallel_setup,
distributed_backend,
task,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
@ -324,7 +335,7 @@ def test_tp_language_generation(
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
("model_name", "parallel_setup", "distributed_backend", "task",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
@ -337,6 +348,7 @@ def test_tp_language_embedding(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
@ -344,6 +356,7 @@ def test_tp_language_embedding(
_compare_tp(model_name,
parallel_setup,
distributed_backend,
task,
trust_remote_code,
tokenizer_mode,
num_gpus_available,
@ -351,7 +364,7 @@ def test_tp_language_embedding(
@pytest.mark.parametrize(
("model_name", "parallel_setup", "distributed_backend",
("model_name", "parallel_setup", "distributed_backend", "task",
"trust_remote_code", "tokenizer_mode"),
[
params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
@ -364,6 +377,7 @@ def test_tp_multimodal_generation(
model_name: str,
parallel_setup: ParallelSetup,
distributed_backend: str,
task: TaskOption,
trust_remote_code: bool,
tokenizer_mode: Optional[str],
num_gpus_available,
@ -371,6 +385,7 @@ def test_tp_multimodal_generation(
_compare_tp(model_name,
parallel_setup,
distributed_backend,
task,
trust_remote_code,
tokenizer_mode,
num_gpus_available,

View File

@ -0,0 +1,92 @@
from typing import List
import pytest
from vllm import LLM
from ..openai.test_vision import TEST_IMAGE_URLS
def test_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
def test_multi_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."
conversation1 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
conversation2 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
]
messages = [conversation1, conversation2]
outputs = llm.chat(messages)
assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
)
messages = [{
"role":
"user",
"content": [
*({
"type": "image_url",
"image_url": {
"url": image_url
}
} for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
},
],
}]
outputs = llm.chat(messages)
assert len(outputs) >= 0

View File

@ -6,7 +6,6 @@ import pytest
from vllm import LLM, RequestOutput, SamplingParams
from ...conftest import cleanup
from ..openai.test_vision import TEST_IMAGE_URLS
MODEL_NAME = "facebook/opt-125m"
@ -104,90 +103,3 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs)
def test_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
def test_multi_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."
conversation1 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
conversation2 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
]
messages = [conversation1, conversation2]
outputs = llm.chat(messages)
assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
)
messages = [{
"role":
"user",
"content": [
*({
"type": "image_url",
"image_url": {
"url": image_url
}
} for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
},
],
}]
outputs = llm.chat(messages)
assert len(outputs) >= 0

View File

@ -0,0 +1,22 @@
import pytest
from vllm import LLM
from ...utils import error_on_warning
MODEL_NAME = "facebook/opt-125m"
def test_pos_args_deprecated():
with error_on_warning(DeprecationWarning):
LLM(model=MODEL_NAME, tokenizer=MODEL_NAME)
with error_on_warning(DeprecationWarning):
LLM(MODEL_NAME, tokenizer=MODEL_NAME)
with pytest.warns(DeprecationWarning, match="'tokenizer'"):
LLM(MODEL_NAME, MODEL_NAME)
with pytest.warns(DeprecationWarning,
match="'tokenizer', 'tokenizer_mode'"):
LLM(MODEL_NAME, MODEL_NAME, "auto")

View File

@ -22,12 +22,12 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()

View File

@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def server():
args = [
"--task",
"generate",
"--dtype",
"bfloat16",
"--max-model-len",

View File

@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@pytest.fixture(scope="module")
def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID,
PHI3V_MODEL_ID,
task="generate",
tokenizer=PHI3V_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",

View File

@ -15,7 +15,8 @@ def test_worker_apply_lora(sql_lora_files):
worker = Worker(
model_config=ModelConfig(
"meta-llama/Llama-2-7b-hf",
"meta-llama/Llama-2-7b-hf",
task="auto",
tokenizer="meta-llama/Llama-2-7b-hf",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
@ -27,7 +28,7 @@ def test_worker_apply_lora(sql_lora_files):
load_format="dummy",
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32),
scheduler_config=SchedulerConfig("generate", 32, 32, 32),
device_config=DeviceConfig("cuda"),
cache_config=CacheConfig(block_size=16,
gpu_memory_utilization=1.,

View File

@ -89,6 +89,7 @@ def run_test(
# max_model_len should be greater than image_feature_size
with vllm_runner(model,
task="generate",
max_model_len=4096,
max_num_seqs=2,
dtype=dtype,

View File

@ -28,6 +28,7 @@ def test_models(
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
with vllm_runner(model,
task="embedding",
max_model_len=4096,
max_num_seqs=2,
dtype=dtype,

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
import torch
from vllm.config import ModelConfig
from vllm.config import ModelConfig, TaskOption
from vllm.inputs import InputContext
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
from vllm.utils import is_cpu
@ -248,6 +248,7 @@ def check_logprobs_close(
def build_model_context(model_name: str,
task: TaskOption = "auto",
tokenizer_name: Optional[str] = None,
trust_remote_code: bool = False,
dtype: Optional[Union[str, torch.dtype]] = None,
@ -273,7 +274,8 @@ def build_model_context(model_name: str,
model_config = ModelConfig(
model_name,
tokenizer_name,
task=task,
tokenizer=tokenizer_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
dtype=dtype,

View File

@ -24,6 +24,7 @@ def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
model_config = ModelConfig(
model=MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
@ -67,6 +68,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype,
model_config = ModelConfig(
model=MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
@ -109,6 +111,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
model_config = ModelConfig(
model=MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,
@ -139,6 +142,7 @@ def test_image_mapper_multi(image_assets, mm_registry, num_images):
model_config = ModelConfig(
model=MODEL_NAME,
task="auto",
tokenizer=MODEL_NAME,
tokenizer_mode="auto",
trust_remote_code=False,

View File

@ -221,6 +221,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
@ -256,6 +257,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
"""Ensure that max token calcs filters out invalid mm_processor_kwargs"""
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})
@ -278,12 +280,13 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
### Test overrides for the mapper
@pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
"""Ensure that the mapper processor kwargs can fall back to HF models."""
# NOTE - we don't validate bad inputs for the default mapper, because it's
# through the automodel interface in transformers, so we can't easily
# inspect what kwargs are or are not allowed.
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs={"num_crops": num_crops},
limit_mm_per_prompt={"image": 1})
@ -311,6 +314,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
init_num_crops, inference_num_crops)
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=init_kwargs,
limit_mm_per_prompt={"image": 1})
@ -348,6 +352,7 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
"""Ensure that custom mappers filters out invalid mm_processor_kwargs"""
# Should filter out the init time kwargs
ctx = build_model_context(MULTIMODAL_MODEL_ID,
task="generate",
trust_remote_code=True,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": 1})

View File

@ -57,7 +57,8 @@ def test_auto_gptq(model_arg_exptype: Tuple[str, None, str]) -> None:
try:
model_config = ModelConfig(model_path,
model_path,
task="auto",
tokenizer=model_path,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,

View File

@ -2,6 +2,42 @@ import pytest
from vllm.config import ModelConfig
@pytest.mark.parametrize(("model_id", "expected_task"), [
("facebook/opt-125m", "generate"),
("intfloat/e5-mistral-7b-instruct", "embedding"),
])
def test_auto_task(model_id, expected_task):
config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "bad_task"), [
("facebook/opt-125m", "embedding"),
("intfloat/e5-mistral-7b-instruct", "generate"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):
ModelConfig(
model_id,
task=bad_task,
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
MODEL_IDS_EXPECTED = [
("Qwen/Qwen1.5-7B", 32768),
("mistralai/Mistral-7B-v0.1", 4096),
@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected):
model_id, expected = model_id_expected
model_config = ModelConfig(
model_id,
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
@ -32,7 +69,8 @@ def test_get_sliding_window():
# when use_sliding_window is False.
qwen2_model_config = ModelConfig(
"Qwen/Qwen1.5-7B",
"Qwen/Qwen1.5-7B",
task="auto",
tokenizer="Qwen/Qwen1.5-7B",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
@ -49,7 +87,8 @@ def test_get_sliding_window():
mistral_model_config = ModelConfig(
"mistralai/Mistral-7B-v0.1",
"mistralai/Mistral-7B-v0.1",
task="auto",
tokenizer="mistralai/Mistral-7B-v0.1",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
@ -70,7 +109,8 @@ def test_rope_customization():
llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
task="auto",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
@ -82,7 +122,8 @@ def test_rope_customization():
llama_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct",
"meta-llama/Meta-Llama-3-8B-Instruct",
task="auto",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
@ -98,7 +139,8 @@ def test_rope_customization():
longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
"lmsys/longchat-13b-16k",
task="auto",
tokenizer="lmsys/longchat-13b-16k",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",
@ -112,7 +154,8 @@ def test_rope_customization():
longchat_model_config = ModelConfig(
"lmsys/longchat-13b-16k",
"lmsys/longchat-13b-16k",
task="auto",
tokenizer="lmsys/longchat-13b-16k",
tokenizer_mode="auto",
trust_remote_code=False,
dtype="float16",

View File

@ -59,7 +59,7 @@ def test_deprecate_kwargs_always():
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
@ -69,10 +69,10 @@ def test_deprecate_kwargs_never():
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
@ -86,15 +86,15 @@ def test_deprecate_kwargs_dynamic():
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)
is_deprecated = False
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(old_arg=1)
with error_on_warning():
with error_on_warning(DeprecationWarning):
dummy(new_arg=1)

View File

@ -8,7 +8,7 @@ import time
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
import openai
import pytest
@ -454,13 +454,13 @@ def multi_process_parallel(
@contextmanager
def error_on_warning():
def error_on_warning(category: Type[Warning] = Warning):
"""
Within the scope of this context manager, tests will fail if any warning
is emitted.
of the given category is emitted.
"""
with warnings.catch_warnings():
warnings.simplefilter("error")
warnings.filterwarnings("error", category=category)
yield
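
A brief usage sketch of the narrowed helper (assuming ``error_on_warning`` is imported from the tests' shared ``utils`` module, as in the test files above): only the requested warning category is escalated to an error inside the block.

import warnings

import pytest

from tests.utils import error_on_warning  # assumed import path

def test_only_deprecation_warnings_are_escalated():
    with error_on_warning(DeprecationWarning):
        warnings.warn("unrelated", UserWarning)  # other categories still pass

    with pytest.raises(DeprecationWarning):
        with error_on_warning(DeprecationWarning):
            warnings.warn("old API", DeprecationWarning)  # escalated to an error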

View File

@ -1,8 +1,8 @@
import enum
import json
from dataclasses import dataclass, field, fields
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
Optional, Tuple, Type, Union)
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
Mapping, Optional, Set, Tuple, Type, Union)
import torch
from transformers import PretrainedConfig
@ -33,6 +33,9 @@ logger = init_logger(__name__)
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
Task = Literal["generate", "embedding"]
TaskOption = Literal["auto", Task]
class ModelConfig:
"""Configuration for the model.
@ -40,7 +43,11 @@ class ModelConfig:
Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer, and
@ -108,6 +115,7 @@ class ModelConfig:
def __init__(self,
model: str,
task: TaskOption,
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
@ -207,7 +215,11 @@ class ModelConfig:
self.override_neuron_config = override_neuron_config if is_neuron(
) else None
self._verify_embedding_mode()
supported_tasks, task = self._resolve_task(task, self.hf_config)
self.supported_tasks = supported_tasks
self.task: Final = task
self._verify_quantization()
self._verify_cuda_graph()
self._verify_bnb_config()
@ -241,18 +253,41 @@ class ModelConfig:
"either 'auto', 'slow' or 'mistral'.")
self.tokenizer_mode = tokenizer_mode
def _verify_embedding_mode(self) -> None:
architectures = getattr(self.hf_config, "architectures", [])
def _resolve_task(
self,
task_option: TaskOption,
hf_config: PretrainedConfig,
) -> Tuple[Set[Task], Task]:
architectures = getattr(hf_config, "architectures", [])
# TODO: Allow the same model architecture to be specified as either
# generation or embedding model
if "Phi3VForCausalLM" in architectures:
# Match both remote and local names
embedding_mode = "/VLM2Vec" in self.model
task_support: Dict[Task, bool] = {
# NOTE: Listed from highest to lowest priority,
# in case the model supports multiple of them
"generate": ModelRegistry.is_text_generation_model(architectures),
"embedding": ModelRegistry.is_embedding_model(architectures),
}
supported_tasks_lst: List[Task] = [
task for task, is_supported in task_support.items() if is_supported
]
supported_tasks = set(supported_tasks_lst)
if task_option == "auto":
selected_task = next(iter(supported_tasks_lst))
if len(supported_tasks) > 1:
logger.info(
"This model supports multiple tasks: %s. "
"Defaulting to '%s'.", supported_tasks, selected_task)
else:
embedding_mode = ModelRegistry.is_embedding_model(architectures)
if task_option not in supported_tasks:
msg = (
f"This model does not support the '{task_option}' task. "
f"Supported tasks: {supported_tasks}")
raise ValueError(msg)
self.embedding_mode = embedding_mode
selected_task = task_option
return supported_tasks, selected_task
def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None)
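
A standalone sketch of the resolution rule in ``_resolve_task`` above (simplified names; not vLLM's actual API): with "auto", the first supported task in priority order is selected, while an explicit task must be one of the supported tasks.

from typing import Dict, List, Set, Tuple

def resolve_task(task_option: str,
                 task_support: Dict[str, bool]) -> Tuple[Set[str], str]:
    # task_support is ordered from highest to lowest priority, e.g.
    # {"generate": True, "embedding": True} for a dual-support model.
    supported: List[str] = [t for t, ok in task_support.items() if ok]

    if task_option == "auto":
        # The highest-priority supported task wins, so a model that
        # supports both tasks defaults to "generate".
        selected = supported[0]
    elif task_option not in supported:
        raise ValueError(f"This model does not support the {task_option!r} "
                         f"task. Supported tasks: {set(supported)}")
    else:
        selected = task_option

    return set(supported), selected

# resolve_task("auto", {"generate": True, "embedding": True})
# -> ({"generate", "embedding"}, "generate")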
@ -401,7 +436,7 @@ class ModelConfig:
# Async postprocessor is not necessary with embedding mode
# since there is no token generation
if self.embedding_mode:
if self.task == "embedding":
self.use_async_output_proc = False
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
@ -582,11 +617,6 @@ class ModelConfig:
(hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False)))
@property
def is_embedding_model(self) -> bool:
"""Extract the embedding model flag."""
return self.embedding_mode
@property
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
@ -943,6 +973,7 @@ class SchedulerConfig:
"""Scheduler configuration.
Args:
task: The task to use the model for.
max_num_batched_tokens: Maximum number of tokens to be processed in
a single iteration.
max_num_seqs: Maximum number of sequences to be processed in a single
@ -957,7 +988,6 @@ class SchedulerConfig:
prompt latency) before scheduling next prompt.
enable_chunked_prefill: If True, prefill requests can be chunked based
on the remaining max_num_batched_tokens.
embedding_mode: Whether the running model is for embedding.
preemption_mode: Whether to perform preemption by swapping or
recomputation. If not specified, we determine the mode as follows:
We use recomputation by default since it incurs lower overhead than
@ -972,13 +1002,13 @@ class SchedulerConfig:
"""
def __init__(self,
task: Task,
max_num_batched_tokens: Optional[int],
max_num_seqs: int,
max_model_len: int,
num_lookahead_slots: int = 0,
delay_factor: float = 0.0,
enable_chunked_prefill: bool = False,
embedding_mode: bool = False,
is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1,
@ -1002,7 +1032,7 @@ class SchedulerConfig:
# for higher throughput.
max_num_batched_tokens = max(max_model_len, 2048)
if embedding_mode:
if task == "embedding":
# For embedding, choose specific value for higher throughput
max_num_batched_tokens = max(
max_num_batched_tokens,
@ -1022,12 +1052,12 @@ class SchedulerConfig:
"Chunked prefill is enabled with max_num_batched_tokens=%d.",
self.max_num_batched_tokens)
self.task: Final = task
self.max_num_seqs = max_num_seqs
self.max_model_len = max_model_len
self.num_lookahead_slots = num_lookahead_slots
self.delay_factor = delay_factor
self.chunked_prefill_enabled = enable_chunked_prefill
self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
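
A hedged sketch of the token-budget defaulting that the hunks above switch from ``embedding_mode`` to ``task`` (constant copied from the header shown earlier; the helper name and simplified control flow are illustrative):

_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768

def default_max_num_batched_tokens(task: str, max_model_len: int) -> int:
    # Without an explicit value, the budget covers at least one full
    # sequence and at least 2048 tokens for throughput.
    max_num_batched_tokens = max(max_model_len, 2048)
    if task == "embedding":
        # Embedding runs no token generation, so a much larger prefill
        # budget is chosen for higher throughput.
        max_num_batched_tokens = max(max_num_batched_tokens,
                                     _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
    return max_num_batched_tokens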
@ -1239,6 +1269,7 @@ class SpeculativeConfig:
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
task=target_model_config.task,
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,

View File

@ -313,7 +313,7 @@ class Scheduler:
self.lora_config = lora_config
version = "selfattn"
if (self.scheduler_config.embedding_mode
if (self.scheduler_config.task == "embedding"
or self.cache_config.is_attention_free):
version = "placeholder"

View File

@ -3,7 +3,7 @@ import dataclasses
import json
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
Tuple, Type, Union, cast)
Tuple, Type, Union, cast, get_args)
import torch
@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TokenizerPoolConfig)
SpeculativeConfig, TaskOption, TokenizerPoolConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@ -84,6 +84,7 @@ class EngineArgs:
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
@ -198,6 +199,15 @@ class EngineArgs:
type=str,
default=EngineArgs.model,
help='Name or path of the huggingface model to use.')
parser.add_argument(
'--task',
default=EngineArgs.task,
choices=get_args(TaskOption),
help='The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, "auto" '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.')
parser.add_argument(
'--tokenizer',
type=nullable_str,
@ -838,6 +848,7 @@ class EngineArgs:
def create_model_config(self) -> ModelConfig:
return ModelConfig(
model=self.model,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode,
@ -1026,13 +1037,13 @@ class EngineArgs:
" please file an issue with detailed information.")
scheduler_config = SchedulerConfig(
task=model_config.task,
max_num_batched_tokens=self.max_num_batched_tokens,
max_num_seqs=self.max_num_seqs,
max_model_len=model_config.max_model_len,
num_lookahead_slots=num_lookahead_slots,
delay_factor=self.scheduler_delay_factor,
enable_chunked_prefill=self.enable_chunked_prefill,
embedding_mode=model_config.embedding_mode,
is_multimodal_model=model_config.is_multimodal_model,
preemption_mode=self.preemption_mode,
num_scheduler_steps=self.num_scheduler_steps,

View File

@ -344,7 +344,7 @@ class LLMEngine:
observability_config=self.observability_config,
)
if not self.model_config.embedding_mode:
if self.model_config.task != "embedding":
self._initialize_kv_caches()
# If usage stat is enabled, collect relevant info.
@ -1116,7 +1116,7 @@ class LLMEngine:
seq_group.metrics.model_execute_time = (
o.model_execute_time)
if self.model_config.embedding_mode:
if self.model_config.task == "embedding":
self._process_sequence_group_outputs(seq_group, output)
else:
self.output_processor.process_prompt_logprob(seq_group, output)
@ -1855,9 +1855,6 @@ class LLMEngine:
def is_encoder_decoder_model(self):
return self.input_preprocessor.is_encoder_decoder_model()
def is_embedding_model(self):
return self.model_config.is_embedding_model
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
EncoderDecoderInputs]):
if self.model_config.is_multimodal_model:

View File

@ -8,7 +8,7 @@ from tqdm import tqdm
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.arg_utils import EngineArgs, TaskOption
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
@ -29,7 +29,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
@ -108,6 +108,12 @@ class LLM:
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
"""
A flag to toggle whether to deprecate positional arguments in
:meth:`LLM.__init__`.
"""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
@ -117,6 +123,13 @@ class LLM:
cls.DEPRECATE_LEGACY = False
@deprecate_args(
start_index=2, # Ignore self and model
is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
additional_message=(
"All positional arguments other than `model` will be "
"replaced with keyword arguments in an upcoming version."),
)
def __init__(
self,
model: str,
@ -139,6 +152,8 @@ class LLM:
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
# After positional args are removed, move this right below `model`
task: TaskOption = "auto",
**kwargs,
) -> None:
'''
@ -153,6 +168,7 @@ class LLM:
engine_args = EngineArgs(
model=model,
task=task,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
skip_tokenizer_init=skip_tokenizer_init,
@ -316,10 +332,21 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if self.llm_engine.model_config.embedding_mode:
raise ValueError(
task = self.llm_engine.model_config.task
if task != "generate":
messages = [
"LLM.generate() is only supported for (conditional) generation "
"models (XForCausalLM, XForConditionalGeneration).")
"models (XForCausalLM, XForConditionalGeneration).",
]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "generate" in supported_tasks:
messages.append(
"Your model supports the 'generate' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task generate`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
@ -692,10 +719,18 @@ class LLM:
considered legacy and may be deprecated in the future. You should
instead pass them via the ``inputs`` parameter.
"""
if not self.llm_engine.model_config.embedding_mode:
raise ValueError(
"LLM.encode() is only supported for embedding models (XModel)."
)
task = self.llm_engine.model_config.task
if task != "embedding":
messages = ["LLM.encode() is only supported for embedding models."]
supported_tasks = self.llm_engine.model_config.supported_tasks
if "embedding" in supported_tasks:
messages.append(
"Your model supports the 'embedding' task, but is "
f"currently initialized for the '{task}' task. Please "
"initialize the model using `--task embedding`.")
raise ValueError(" ".join(messages))
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
@ -905,6 +940,3 @@ class LLM:
def _is_encoder_decoder_model(self):
return self.llm_engine.is_encoder_decoder_model()
def _is_embedding_model(self):
return self.llm_engine.is_embedding_model()
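
To illustrate the improved error guidance (a hypothetical interaction; the model name is taken from the VLM2Vec example in this commit): calling the wrong method on a dual-support model now points at the ``task`` argument instead of a bare rejection.

from vllm import LLM

# With task="auto", this Phi3VForCausalLM-based model resolves to the
# "generate" task, so encode() is rejected with a hint.
llm = LLM(model="TIGER-Lab/VLM2Vec-Full",
          trust_remote_code=True,
          max_model_len=4096)

try:
    llm.encode("Represent the given sentence: hello world")
except ValueError as err:
    # "...Your model supports the 'embedding' task, but is currently
    # initialized for the 'generate' task. Please initialize the model
    # using `--task embedding`."
    print(err)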

View File

@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._enabled = self._check_embedding_mode(model_config.embedding_mode)
self._enabled = self._check_embedding_mode(
model_config.task == "embedding")
async def create_embedding(
self,

View File

@ -1034,10 +1034,54 @@ def identity(value: T) -> T:
F = TypeVar('F', bound=Callable[..., Any])
def deprecate_args(
start_index: int,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None,
) -> Callable[[F], F]:
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
def wrapper(fn: F) -> F:
params = inspect.signature(fn).parameters
pos_types = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
pos_kws = [
kw for kw, param in params.items() if param.kind in pos_types
]
@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_args = pos_kws[start_index:len(args)]
if deprecated_args:
msg = (
f"The positional arguments {deprecated_args} are "
"deprecated and will be removed in a future update.")
if additional_message is not None:
msg += f" {additional_message}"
warnings.warn(
DeprecationWarning(msg),
stacklevel=3, # The inner function takes up one level
)
return fn(*args, **kwargs)
return inner # type: ignore
return wrapper
def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None) -> Callable[[F], F]:
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None,
) -> Callable[[F], F]:
deprecated_kws = set(kws)
if not callable(is_deprecated):
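
A short usage sketch of the new ``deprecate_args`` decorator (the ``connect`` function is hypothetical): positional arguments at or after ``start_index`` trigger a ``DeprecationWarning``, matching how the ``LLM`` constructor applies it with ``start_index=2``.

from vllm.utils import deprecate_args

@deprecate_args(
    start_index=1,  # the first positional argument remains allowed
    additional_message="Pass the remaining arguments by keyword.",
)
def connect(host: str, port: int = 8000, timeout: int = 10):
    return host, port, timeout

connect("localhost", 8000)       # warns: positional arguments ['port'] are deprecated
connect("localhost", port=8000)  # no warning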

View File

@ -92,7 +92,7 @@ class Worker(LocalOrDistributedWorkerBase):
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
if model_runner_cls is not None:
ModelRunnerClass = model_runner_cls
elif self._is_embedding_model():
elif model_config.task == "embedding":
ModelRunnerClass = EmbeddingModelRunner
elif self._is_encoder_decoder_model():
ModelRunnerClass = EncoderDecoderModelRunner
@ -147,9 +147,6 @@ class Worker(LocalOrDistributedWorkerBase):
def _is_encoder_decoder_model(self):
return self.model_config.is_encoder_decoder_model
def _is_embedding_model(self):
return self.model_config.is_embedding_model
def init_device(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until