[Model] Add user-configurable task for models that support both generation and embedding (#9424)
Parent: 7dbe738d65
Commit: 051eaf6db3
@@ -294,6 +294,10 @@ Text Embedding
     -
     - ✅︎
 
+.. important::
+    Some model architectures support both generation and embedding tasks.
+    In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
+
 Reward Modeling
 ---------------
 
@@ -482,6 +486,10 @@ Multimodal Embedding
     - 🚧
     - ✅︎
 
+.. important::
+    Some model architectures support both generation and embedding tasks.
+    In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
+
 ----
 
 If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
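The note above only states that the task must be chosen explicitly when an architecture supports both generation and embedding. As a minimal offline sketch of what that selection looks like (the model name is illustrative, and the defaulting behaviour follows the `_resolve_task` priority order introduced later in this PR):

.. code-block:: python

    from vllm import LLM

    # With task="auto" (the default), a dual-task architecture resolves to
    # "generate", the higher-priority entry; request the embedding runner
    # explicitly instead.
    llm = LLM(
        model="TIGER-Lab/VLM2Vec-Full",  # example dual-task model from this PR
        task="embedding",
        trust_remote_code=True,
        max_model_len=4096,
    )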
@@ -181,8 +181,8 @@ Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruc
 
 .. code-block:: bash
 
-    vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
-      --trust-remote-code --limit-mm-per-prompt image=2
+    vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
+      --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
 
 .. important::
     Since OpenAI Vision API is based on `Chat Completions <https://platform.openai.com/docs/api-reference/chat>`_ API,
@@ -7,6 +7,7 @@ prompt = "<|image_1|> Represent the given image with the following question: Wha
 # Create an LLM.
 llm = LLM(
     model="TIGER-Lab/VLM2Vec-Full",
+    task="embedding",
     trust_remote_code=True,
     max_model_len=4096,
     max_num_seqs=2,
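For context, a self-contained sketch of how the example above is driven end to end; the image asset and ``llm.encode()`` call mirror vLLM's offline embedding examples of this period, while the prompt text itself is illustrative:

.. code-block:: python

    from vllm import LLM
    from vllm.assets.image import ImageAsset

    llm = LLM(
        model="TIGER-Lab/VLM2Vec-Full",
        task="embedding",
        trust_remote_code=True,
        max_model_len=4096,
        max_num_seqs=2,
    )

    # Encode an image together with an instruction-style prompt and read back
    # the pooled embedding vector.
    image = ImageAsset("cherry_blossom").pil_image
    outputs = llm.encode({
        "prompt": "<|image_1|> Represent the given image.",
        "multi_modal_data": {"image": image},
    })
    print(len(outputs[0].outputs.embedding))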
@@ -7,8 +7,8 @@ Launch the vLLM server with the following command:
 vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
 
 (multi-image inference with Phi-3.5-vision-instruct)
-vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \
-  --trust-remote-code --limit-mm-per-prompt image=2
+vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
+  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
 
 (audio inference with Ultravox)
 vllm serve fixie-ai/ultravox-v0_3 --max-model-len 4096
@@ -25,7 +25,7 @@ from tests.models.utils import (TokensTextLogprobs,
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
-from vllm.config import TokenizerPoolConfig
+from vllm.config import TaskOption, TokenizerPoolConfig
 from vllm.connections import global_http_connection
 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel,
@@ -619,6 +619,7 @@ class VllmRunner:
     def __init__(
         self,
         model_name: str,
+        task: TaskOption = "auto",
         tokenizer_name: Optional[str] = None,
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
@@ -634,6 +635,7 @@ class VllmRunner:
     ) -> None:
         self.model = LLM(
             model=model_name,
+            task=task,
             tokenizer=tokenizer_name,
             trust_remote_code=True,
             dtype=dtype,
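A hypothetical test sketch (not part of this diff) showing how the new ``task`` argument threads through the ``vllm_runner`` fixture; the ``encode()`` helper is assumed from the runner's existing embedding support:

.. code-block:: python

    def test_embedding_runner_smoke(vllm_runner):
        # The fixture forwards task="embedding" into LLM(...), so the model is
        # loaded with the embedding runner rather than the generation runner.
        with vllm_runner("intfloat/e5-mistral-7b-instruct",
                         task="embedding",
                         max_model_len=4096,
                         dtype="half") as vllm_model:
            embeddings = vllm_model.encode(["Hello, my name is"])
            assert len(embeddings) == 1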
@@ -33,7 +33,8 @@ def test_simple():
     num_seq_group = 4
     max_model_len = 16
     max_num_batched_tokens = 64
-    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+    scheduler_config = SchedulerConfig("generate",
+                                       max_num_batched_tokens,
                                        num_seq_group,
                                        max_model_len,
                                        enable_chunked_prefill=True)
@@ -78,6 +79,7 @@ def test_chunk():
     max_model_len = 80
     max_num_batched_tokens = 64
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -126,6 +128,7 @@ def test_complex():
     max_model_len = 80
     max_num_batched_tokens = 64
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -196,6 +199,7 @@ def test_maximal_decoding():
     max_model_len = 8
     max_num_batched_tokens = 2
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -289,6 +293,7 @@ def test_prompt_limit():
     max_model_len = 64
     max_num_batched_tokens = 32
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -321,7 +326,8 @@ def test_prompt_limit_exceed():
     max_seqs = 64
     max_model_len = 32
     max_num_batched_tokens = 64
-    scheduler_config = SchedulerConfig(max_num_batched_tokens,
+    scheduler_config = SchedulerConfig("generate",
+                                       max_num_batched_tokens,
                                        max_seqs,
                                        max_model_len,
                                        enable_chunked_prefill=True)
@@ -348,6 +354,7 @@ def test_swap():
     max_model_len = 200
     max_num_batched_tokens = 30
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -404,6 +411,7 @@ def test_running_prefill_prioritized_over_swap():
     max_model_len = 200
     max_num_batched_tokens = 30
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -498,6 +506,7 @@ def test_chunked_prefill_preempt():
     max_model_len = 200
     max_num_batched_tokens = 30
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -563,6 +572,7 @@ def test_chunked_prefill_max_seqs():
     max_model_len = 80
     max_num_batched_tokens = 64
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
@@ -617,6 +627,7 @@ def test_perfix_caching():
     max_model_len = 80
     max_num_batched_tokens = 64
     scheduler_config = SchedulerConfig(
+        "generate",
         max_num_batched_tokens,
         max_seqs,
         max_model_len,
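Every scheduler test in this file (and the ones that follow) now passes the task ahead of the batching limits; a minimal sketch of the updated constructor call, with arbitrary test-style values:

.. code-block:: python

    from vllm.config import SchedulerConfig

    scheduler_config = SchedulerConfig(
        "generate",                 # the task now comes first
        max_num_batched_tokens=64,  # token budget per scheduler iteration
        max_num_seqs=4,
        max_model_len=16,
        enable_chunked_prefill=True,
    )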
@@ -20,9 +20,10 @@ from .utils import (append_new_token, append_new_token_seq_group,
 def test_scheduler_add_seq_group():
     block_size = 4
     scheduler_config = SchedulerConfig(
-        100,
-        64,
-        1,
+        "generate",
+        max_num_batched_tokens=100,
+        max_num_seqs=64,
+        max_model_len=1,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, cache_dtype="auto")
     cache_config.num_cpu_blocks = 4
@@ -42,9 +43,10 @@ def test_scheduler_add_seq_group():
 def test_scheduler_abort_seq_group():
     block_size = 4
     scheduler_config = SchedulerConfig(
-        100,
-        64,
-        1,
+        "generate",
+        max_num_batched_tokens=100,
+        max_num_seqs=64,
+        max_model_len=1,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 4
@@ -70,9 +72,10 @@ def test_scheduler_schedule_simple():
     num_seq_group = 4
     max_model_len = 16
     scheduler_config = SchedulerConfig(
-        64,
-        num_seq_group,
-        max_model_len,
+        "generate",
+        max_num_batched_tokens=64,
+        max_num_seqs=num_seq_group,
+        max_model_len=max_model_len,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 8
@@ -114,9 +117,10 @@ def test_scheduler_prefill_prioritized():
     max_model_len = 30
     max_batched_num_tokens = 30
     scheduler_config = SchedulerConfig(
-        max_batched_num_tokens,
-        2,
-        max_model_len,
+        "generate",
+        max_num_batched_tokens=max_batched_num_tokens,
+        max_num_seqs=2,
+        max_model_len=max_model_len,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 16
@@ -145,9 +149,10 @@ def test_scheduler_schedule_preempt_abort():
     block_size = 4
     max_model_len = 16
     scheduler_config = SchedulerConfig(
-        64,
-        2,
-        max_model_len,
+        "generate",
+        max_num_batched_tokens=64,
+        max_num_seqs=2,
+        max_model_len=max_model_len,
    )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 2
@@ -204,9 +209,10 @@ def test_scheduler_max_seqs():
     max_seq_group = 2
     max_model_len = 16
     scheduler_config = SchedulerConfig(
-        64,
-        max_seq_group,
-        max_model_len,
+        "generate",
+        max_num_batched_tokens=64,
+        max_num_seqs=max_seq_group,
+        max_model_len=max_model_len,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = 8
@@ -248,9 +254,10 @@ def test_scheduler_max_seqs():
 def test_scheduler_delay_factor():
     block_size = 4
     scheduler_config = SchedulerConfig(
-        100,
-        64,
-        16,
+        "generate",
+        max_num_batched_tokens=100,
+        max_num_seqs=64,
+        max_model_len=16,
         delay_factor=0.5,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
@@ -350,9 +357,10 @@ def initialize_scheduler(
 ):
     block_size = block_size
     scheduler_config = SchedulerConfig(
-        max_token_budget,
-        max_num_seqs,
-        max_model_len,
+        "generate",
+        max_num_batched_tokens=max_token_budget,
+        max_num_seqs=max_num_seqs,
+        max_model_len=max_model_len,
     )
     cache_config = CacheConfig(block_size, 1.0, 1, "auto")
     cache_config.num_cpu_blocks = num_cpu_blocks
|
|||||||
block_size = 4
|
block_size = 4
|
||||||
num_seq_group = 4
|
num_seq_group = 4
|
||||||
max_model_len = 16
|
max_model_len = 16
|
||||||
scheduler_config = SchedulerConfig(64, num_seq_group, max_model_len)
|
scheduler_config = SchedulerConfig(
|
||||||
|
task="generate",
|
||||||
|
max_num_batched_tokens=64,
|
||||||
|
max_num_seqs=num_seq_group,
|
||||||
|
max_model_len=max_model_len,
|
||||||
|
)
|
||||||
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
cache_config = CacheConfig(block_size, 1.0, 1, "auto")
|
||||||
cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group
|
cache_config.num_cpu_blocks = 16 # enc and dec prompts per seq_group
|
||||||
cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group
|
cache_config.num_gpu_blocks = 16 # enc and dec prompts per seq_group
|
||||||
|
@@ -11,6 +11,7 @@ from typing import List, Literal, NamedTuple, Optional
 
 import pytest
 
+from vllm.config import TaskOption
 from vllm.logger import init_logger
 
 from ..utils import compare_two_settings, fork_new_process_for_each_test
@@ -31,6 +32,7 @@ class ParallelSetup(NamedTuple):
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
     distributed_backends: List[str]
+    task: TaskOption
     trust_remote_code: bool
     tokenizer_mode: Optional[str]
 
@@ -39,6 +41,7 @@ class PPTestSettings:
         *,
         tp_base: int = 1,
         pp_base: int = 2,
+        task: TaskOption = "auto",
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
     ):
@@ -66,6 +69,7 @@ class PPTestSettings:
                               chunked_prefill=False),
             ],
             distributed_backends=["mp", "ray"],
+            task=task,
             trust_remote_code=trust_remote_code,
             tokenizer_mode=tokenizer_mode,
         )
@@ -75,6 +79,7 @@ class PPTestSettings:
         *,
         tp_base: int = 1,
         pp_base: int = 2,
+        task: TaskOption = "auto",
         trust_remote_code: bool = False,
         tokenizer_mode: Optional[str] = None,
     ):
@@ -86,6 +91,7 @@ class PPTestSettings:
                               chunked_prefill=False),
             ],
             distributed_backends=["mp"],
+            task=task,
             trust_remote_code=trust_remote_code,
             tokenizer_mode=tokenizer_mode,
         )
@@ -94,7 +100,7 @@ class PPTestSettings:
         for parallel_setup in self.parallel_setups:
             for distributed_backend in self.distributed_backends:
                 yield (model_name, parallel_setup, distributed_backend,
-                       self.trust_remote_code, self.tokenizer_mode)
+                       self.task, self.trust_remote_code, self.tokenizer_mode)
 
 
 # NOTE: You can adjust tp_base and/or pp_base locally to fit the model in GPU
@@ -213,6 +219,7 @@ def _compare_tp(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    task: TaskOption,
     trust_remote_code: bool,
    tokenizer_mode: Optional[str],
     num_gpus_available: int,
@@ -240,6 +247,8 @@ def _compare_tp(
         common_args.append("--enable-chunked-prefill")
     if eager_mode:
         common_args.append("--enforce-eager")
+    if task != "auto":
+        common_args.extend(["--task", task])
     if trust_remote_code:
         common_args.append("--trust-remote-code")
     if tokenizer_mode:
@@ -297,7 +306,7 @@ def _compare_tp(
 
 
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend",
+    ("model_name", "parallel_setup", "distributed_backend", "task",
      "trust_remote_code", "tokenizer_mode"),
     [
         params for model_name, settings in GENERATION_MODEL_SETTINGS.items()
@@ -310,6 +319,7 @@ def test_tp_language_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    task: TaskOption,
     trust_remote_code: bool,
     tokenizer_mode: Optional[str],
     num_gpus_available,
@@ -317,6 +327,7 @@ def test_tp_language_generation(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                task,
                 trust_remote_code,
                 tokenizer_mode,
                 num_gpus_available,
@@ -324,7 +335,7 @@ def test_tp_language_generation(
 
 
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend",
+    ("model_name", "parallel_setup", "distributed_backend", "task",
     "trust_remote_code", "tokenizer_mode"),
     [
         params for model_name, settings in EMBEDDING_MODEL_SETTINGS.items()
@@ -337,6 +348,7 @@ def test_tp_language_embedding(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    task: TaskOption,
     trust_remote_code: bool,
     tokenizer_mode: Optional[str],
     num_gpus_available,
@@ -344,6 +356,7 @@ def test_tp_language_embedding(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                task,
                 trust_remote_code,
                 tokenizer_mode,
                 num_gpus_available,
@@ -351,7 +364,7 @@ def test_tp_language_embedding(
 
 
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend",
+    ("model_name", "parallel_setup", "distributed_backend", "task",
     "trust_remote_code", "tokenizer_mode"),
     [
         params for model_name, settings in MULTIMODAL_MODEL_SETTINGS.items()
@@ -364,6 +377,7 @@ def test_tp_multimodal_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    task: TaskOption,
     trust_remote_code: bool,
     tokenizer_mode: Optional[str],
     num_gpus_available,
@@ -371,6 +385,7 @@ def test_tp_multimodal_generation(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                task,
                 trust_remote_code,
                 tokenizer_mode,
                 num_gpus_available,
tests/entrypoints/llm/test_chat.py (new file, 92 lines)

@@ -0,0 +1,92 @@
+from typing import List
+
+import pytest
+
+from vllm import LLM
+
+from ..openai.test_vision import TEST_IMAGE_URLS
+
+
+def test_chat():
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+    outputs = llm.chat(messages)
+    assert len(outputs) == 1
+
+
+def test_multi_chat():
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
+
+    prompt1 = "Explain the concept of entropy."
+    prompt2 = "Explain what among us is."
+
+    conversation1 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt1
+        },
+    ]
+
+    conversation2 = [
+        {
+            "role": "system",
+            "content": "You are a helpful assistant"
+        },
+        {
+            "role": "user",
+            "content": prompt2
+        },
+    ]
+
+    messages = [conversation1, conversation2]
+
+    outputs = llm.chat(messages)
+    assert len(outputs) == 2
+
+
+@pytest.mark.parametrize("image_urls",
+                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
+def test_chat_multi_image(image_urls: List[str]):
+    llm = LLM(
+        model="microsoft/Phi-3.5-vision-instruct",
+        dtype="bfloat16",
+        max_model_len=4096,
+        max_num_seqs=5,
+        enforce_eager=True,
+        trust_remote_code=True,
+        limit_mm_per_prompt={"image": 2},
+    )
+
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *({
+                "type": "image_url",
+                "image_url": {
+                    "url": image_url
+                }
+            } for image_url in image_urls),
+            {
+                "type": "text",
+                "text": "What's in this image?"
+            },
+        ],
+    }]
+    outputs = llm.chat(messages)
+    assert len(outputs) >= 0
@@ -6,7 +6,6 @@ import pytest
 from vllm import LLM, RequestOutput, SamplingParams
 
 from ...conftest import cleanup
-from ..openai.test_vision import TEST_IMAGE_URLS
 
 MODEL_NAME = "facebook/opt-125m"
 
@@ -104,90 +103,3 @@ def test_multiple_sampling_params(llm: LLM):
     # sampling_params is None, default params should be applied
     outputs = llm.generate(PROMPTS, sampling_params=None)
     assert len(PROMPTS) == len(outputs)
-
-
-def test_chat():
-
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
-
-    prompt1 = "Explain the concept of entropy."
-    messages = [
-        {
-            "role": "system",
-            "content": "You are a helpful assistant"
-        },
-        {
-            "role": "user",
-            "content": prompt1
-        },
-    ]
-    outputs = llm.chat(messages)
-    assert len(outputs) == 1
-
-
-def test_multi_chat():
-
-    llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
-
-    prompt1 = "Explain the concept of entropy."
-    prompt2 = "Explain what among us is."
-
-    conversation1 = [
-        {
-            "role": "system",
-            "content": "You are a helpful assistant"
-        },
-        {
-            "role": "user",
-            "content": prompt1
-        },
-    ]
-
-    conversation2 = [
-        {
-            "role": "system",
-            "content": "You are a helpful assistant"
-        },
-        {
-            "role": "user",
-            "content": prompt2
-        },
-    ]
-
-    messages = [conversation1, conversation2]
-
-    outputs = llm.chat(messages)
-    assert len(outputs) == 2
-
-
-@pytest.mark.parametrize("image_urls",
-                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
-def test_chat_multi_image(image_urls: List[str]):
-    llm = LLM(
-        model="microsoft/Phi-3.5-vision-instruct",
-        dtype="bfloat16",
-        max_model_len=4096,
-        max_num_seqs=5,
-        enforce_eager=True,
-        trust_remote_code=True,
-        limit_mm_per_prompt={"image": 2},
-    )
-
-    messages = [{
-        "role":
-        "user",
-        "content": [
-            *({
-                "type": "image_url",
-                "image_url": {
-                    "url": image_url
-                }
-            } for image_url in image_urls),
-            {
-                "type": "text",
-                "text": "What's in this image?"
-            },
-        ],
-    }]
-    outputs = llm.chat(messages)
-    assert len(outputs) >= 0
tests/entrypoints/llm/test_init.py (new file, 22 lines)

@@ -0,0 +1,22 @@
+import pytest
+
+from vllm import LLM
+
+from ...utils import error_on_warning
+
+MODEL_NAME = "facebook/opt-125m"
+
+
+def test_pos_args_deprecated():
+    with error_on_warning(DeprecationWarning):
+        LLM(model=MODEL_NAME, tokenizer=MODEL_NAME)
+
+    with error_on_warning(DeprecationWarning):
+        LLM(MODEL_NAME, tokenizer=MODEL_NAME)
+
+    with pytest.warns(DeprecationWarning, match="'tokenizer'"):
+        LLM(MODEL_NAME, MODEL_NAME)
+
+    with pytest.warns(DeprecationWarning,
+                      match="'tokenizer', 'tokenizer_mode'"):
+        LLM(MODEL_NAME, MODEL_NAME, "auto")
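The new test asserts that positional arguments after the model name are deprecated; a hypothetical keyword-only construction (not part of this diff) that emits no warning:

.. code-block:: python

    from vllm import LLM

    # Everything after the model is passed by keyword, so no DeprecationWarning
    # is raised.
    llm = LLM(model="facebook/opt-125m",
              tokenizer="facebook/opt-125m",
              tokenizer_mode="auto")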
@@ -22,12 +22,12 @@ class MockHFConfig:
 
 @dataclass
 class MockModelConfig:
+    task = "generate"
     tokenizer = MODEL_NAME
     trust_remote_code = False
     tokenizer_mode = "auto"
     max_model_len = 100
     tokenizer_revision = None
-    embedding_mode = False
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
 
@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
 @pytest.fixture(scope="module")
 def server():
     args = [
+        "--task",
+        "generate",
         "--dtype",
         "bfloat16",
         "--max-model-len",
@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 @pytest.fixture(scope="module")
 def phi3v_model_config():
     return ModelConfig(PHI3V_MODEL_ID,
-                       PHI3V_MODEL_ID,
+                       task="generate",
+                       tokenizer=PHI3V_MODEL_ID,
                        tokenizer_mode="auto",
                        trust_remote_code=True,
                        dtype="bfloat16",
@@ -15,7 +15,8 @@ def test_worker_apply_lora(sql_lora_files):
     worker = Worker(
         model_config=ModelConfig(
             "meta-llama/Llama-2-7b-hf",
-            "meta-llama/Llama-2-7b-hf",
+            task="auto",
+            tokenizer="meta-llama/Llama-2-7b-hf",
             tokenizer_mode="auto",
             trust_remote_code=False,
             seed=0,
@@ -27,7 +28,7 @@ def test_worker_apply_lora(sql_lora_files):
             load_format="dummy",
         ),
         parallel_config=ParallelConfig(1, 1, False),
-        scheduler_config=SchedulerConfig(32, 32, 32),
+        scheduler_config=SchedulerConfig("generate", 32, 32, 32),
         device_config=DeviceConfig("cuda"),
         cache_config=CacheConfig(block_size=16,
                                  gpu_memory_utilization=1.,
@@ -89,6 +89,7 @@ def run_test(
 
     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
+                     task="generate",
                      max_model_len=4096,
                      max_num_seqs=2,
                      dtype=dtype,
@@ -28,6 +28,7 @@ def test_models(
     # if we run HF first, the cuda initialization will be done and it
     # will hurt multiprocessing backend with fork method (the default method).
     with vllm_runner(model,
+                     task="embedding",
                      max_model_len=4096,
                      max_num_seqs=2,
                      dtype=dtype,
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, TaskOption
 from vllm.inputs import InputContext
 from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
 from vllm.utils import is_cpu
@@ -248,6 +248,7 @@ def check_logprobs_close(
 
 
 def build_model_context(model_name: str,
+                        task: TaskOption = "auto",
                         tokenizer_name: Optional[str] = None,
                         trust_remote_code: bool = False,
                         dtype: Optional[Union[str, torch.dtype]] = None,
@@ -273,7 +274,8 @@ def build_model_context(model_name: str,
 
     model_config = ModelConfig(
         model_name,
-        tokenizer_name,
+        task=task,
+        tokenizer=tokenizer_name,
         tokenizer_mode="auto",
         trust_remote_code=trust_remote_code,
         dtype=dtype,
@@ -24,6 +24,7 @@ def test_clip_image_processor(image_assets, mm_registry, dtype, size_factor):
 
     model_config = ModelConfig(
         model=MODEL_NAME,
+        task="auto",
         tokenizer=MODEL_NAME,
         tokenizer_mode="auto",
         trust_remote_code=False,
@@ -67,6 +68,7 @@ def test_llava_next_image_processor(image_assets, mm_registry, dtype,
 
     model_config = ModelConfig(
         model=MODEL_NAME,
+        task="auto",
         tokenizer=MODEL_NAME,
         tokenizer_mode="auto",
         trust_remote_code=False,
@@ -109,6 +111,7 @@ def test_mm_limits(image_assets, mm_registry, num_images, limit, is_valid):
 
     model_config = ModelConfig(
         model=MODEL_NAME,
+        task="auto",
         tokenizer=MODEL_NAME,
         tokenizer_mode="auto",
         trust_remote_code=False,
@@ -139,6 +142,7 @@ def test_image_mapper_multi(image_assets, mm_registry, num_images):
 
     model_config = ModelConfig(
         model=MODEL_NAME,
+        task="auto",
         tokenizer=MODEL_NAME,
         tokenizer_mode="auto",
         trust_remote_code=False,
@@ -221,6 +221,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
     expected_seq_count = DEFAULT_NUM_CROPS if num_crops is None else num_crops
 
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              task="generate",
                               trust_remote_code=True,
                               mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})
@@ -256,6 +257,7 @@ def test_max_tokens_kwarg_overrides(num_crops):
 def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
     """Ensure that max token calcs filters out invalid mm_processor_kwargs"""
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              task="generate",
                               trust_remote_code=True,
                               mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})
@@ -278,12 +280,13 @@ def test_max_tokens_with_sad_kwarg_overrides(mm_processor_kwargs):
 
 ### Test overrides for the mapper
 @pytest.mark.parametrize("num_crops", [DEFAULT_NUM_CROPS, NUM_CROPS_OVERRIDE])
-def test_default_mapper_with_processer_kwargs(image_assets, num_crops):
+def test_default_mapper_with_processor_kwargs(image_assets, num_crops):
     """Ensure that the mapper processor kwargs can fall back to HF models."""
     # NOTE - we don't validate bad inputs for the default mapper, because it's
     # through the automodel interface in transformers, so we can't easily
     # inspect what kwargs are or are not allowed.
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              task="generate",
                               trust_remote_code=True,
                               mm_processor_kwargs={"num_crops": num_crops},
                               limit_mm_per_prompt={"image": 1})
@@ -311,6 +314,7 @@ def test_custom_mapper_kwarg_overrides(image_assets, init_num_crops,
                                           init_num_crops, inference_num_crops)
 
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              task="generate",
                               trust_remote_code=True,
                               mm_processor_kwargs=init_kwargs,
                               limit_mm_per_prompt={"image": 1})
@@ -348,6 +352,7 @@ def test_custom_mapper_with_sad_kwarg_overrides(image_assets,
     """Ensure that custom mappers filters out invalid mm_processor_kwargs"""
     # Should filter out the init time kwargs
     ctx = build_model_context(MULTIMODAL_MODEL_ID,
+                              task="generate",
                               trust_remote_code=True,
                               mm_processor_kwargs=mm_processor_kwargs,
                               limit_mm_per_prompt={"image": 1})
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
model_config = ModelConfig(model_path,
|
model_config = ModelConfig(model_path,
|
||||||
model_path,
|
task="auto",
|
||||||
|
tokenizer=model_path,
|
||||||
tokenizer_mode="auto",
|
tokenizer_mode="auto",
|
||||||
trust_remote_code=False,
|
trust_remote_code=False,
|
||||||
seed=0,
|
seed=0,
|
||||||
|
@@ -2,6 +2,42 @@ import pytest
 
 from vllm.config import ModelConfig
 
 
+@pytest.mark.parametrize(("model_id", "expected_task"), [
+    ("facebook/opt-125m", "generate"),
+    ("intfloat/e5-mistral-7b-instruct", "embedding"),
+])
+def test_auto_task(model_id, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.task == expected_task
+
+
+@pytest.mark.parametrize(("model_id", "bad_task"), [
+    ("facebook/opt-125m", "embedding"),
+    ("intfloat/e5-mistral-7b-instruct", "generate"),
+])
+def test_incorrect_task(model_id, bad_task):
+    with pytest.raises(ValueError, match=r"does not support the .* task"):
+        ModelConfig(
+            model_id,
+            task=bad_task,
+            tokenizer=model_id,
+            tokenizer_mode="auto",
+            trust_remote_code=False,
+            seed=0,
+            dtype="float16",
+        )
+
+
 MODEL_IDS_EXPECTED = [
     ("Qwen/Qwen1.5-7B", 32768),
     ("mistralai/Mistral-7B-v0.1", 4096),
@@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected):
     model_id, expected = model_id_expected
     model_config = ModelConfig(
         model_id,
-        model_id,
+        task="auto",
+        tokenizer=model_id,
         tokenizer_mode="auto",
         trust_remote_code=False,
         seed=0,
@@ -32,7 +69,8 @@ def test_get_sliding_window():
     # when use_sliding_window is False.
     qwen2_model_config = ModelConfig(
         "Qwen/Qwen1.5-7B",
-        "Qwen/Qwen1.5-7B",
+        task="auto",
+        tokenizer="Qwen/Qwen1.5-7B",
         tokenizer_mode="auto",
         trust_remote_code=False,
         seed=0,
@@ -49,7 +87,8 @@ def test_get_sliding_window():
 
     mistral_model_config = ModelConfig(
         "mistralai/Mistral-7B-v0.1",
-        "mistralai/Mistral-7B-v0.1",
+        task="auto",
+        tokenizer="mistralai/Mistral-7B-v0.1",
         tokenizer_mode="auto",
         trust_remote_code=False,
         seed=0,
@@ -70,7 +109,8 @@ def test_rope_customization():
 
     llama_model_config = ModelConfig(
         "meta-llama/Meta-Llama-3-8B-Instruct",
-        "meta-llama/Meta-Llama-3-8B-Instruct",
+        task="auto",
+        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
@@ -82,7 +122,8 @@ def test_rope_customization():
 
     llama_model_config = ModelConfig(
         "meta-llama/Meta-Llama-3-8B-Instruct",
-        "meta-llama/Meta-Llama-3-8B-Instruct",
+        task="auto",
+        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
@@ -98,7 +139,8 @@ def test_rope_customization():
 
     longchat_model_config = ModelConfig(
         "lmsys/longchat-13b-16k",
-        "lmsys/longchat-13b-16k",
+        task="auto",
+        tokenizer="lmsys/longchat-13b-16k",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
@@ -112,7 +154,8 @@ def test_rope_customization():
 
     longchat_model_config = ModelConfig(
         "lmsys/longchat-13b-16k",
-        "lmsys/longchat-13b-16k",
+        task="auto",
+        tokenizer="lmsys/longchat-13b-16k",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
@@ -59,7 +59,7 @@ def test_deprecate_kwargs_always():
     with pytest.warns(DeprecationWarning, match="'old_arg'"):
         dummy(old_arg=1)
 
-    with error_on_warning():
+    with error_on_warning(DeprecationWarning):
         dummy(new_arg=1)
 
 
@@ -69,10 +69,10 @@ def test_deprecate_kwargs_never():
     def dummy(*, old_arg: object = None, new_arg: object = None):
         pass
 
-    with error_on_warning():
+    with error_on_warning(DeprecationWarning):
         dummy(old_arg=1)
 
-    with error_on_warning():
+    with error_on_warning(DeprecationWarning):
         dummy(new_arg=1)
 
 
@@ -86,15 +86,15 @@ def test_deprecate_kwargs_dynamic():
     with pytest.warns(DeprecationWarning, match="'old_arg'"):
         dummy(old_arg=1)
 
-    with error_on_warning():
+    with error_on_warning(DeprecationWarning):
         dummy(new_arg=1)
 
     is_deprecated = False
 
-    with error_on_warning():
+    with error_on_warning(DeprecationWarning):
         dummy(old_arg=1)
 
-    with error_on_warning():
+    with error_on_warning(DeprecationWarning):
         dummy(new_arg=1)
 
 
@@ -8,7 +8,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Literal, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
 
 import openai
 import pytest
@@ -454,13 +454,13 @@ def multi_process_parallel(
 
 
 @contextmanager
-def error_on_warning():
+def error_on_warning(category: Type[Warning] = Warning):
     """
     Within the scope of this context manager, tests will fail if any warning
-    is emitted.
+    of the given category is emitted.
     """
     with warnings.catch_warnings():
-        warnings.simplefilter("error")
+        warnings.filterwarnings("error", category=category)
 
         yield
 
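A small usage sketch (not in the diff) of the narrowed helper: only the requested warning category is escalated to an error, so unrelated warnings no longer fail the test. The import path is illustrative of how the test suite reaches this helper.

.. code-block:: python

    import warnings

    from tests.utils import error_on_warning  # illustrative import path

    with error_on_warning(DeprecationWarning):
        # A DeprecationWarning raised here would fail the test...
        warnings.warn("unrelated notice", UserWarning)  # ...but this passes through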
@@ -1,8 +1,8 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
-                    Optional, Tuple, Type, Union)
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Final, List, Literal,
+                    Mapping, Optional, Set, Tuple, Type, Union)
 
 import torch
 from transformers import PretrainedConfig
@@ -33,6 +33,9 @@ logger = init_logger(__name__)
 _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
 
+Task = Literal["generate", "embedding"]
+TaskOption = Literal["auto", Task]
+
 
 class ModelConfig:
     """Configuration for the model.
@@ -40,7 +43,11 @@ class ModelConfig:
     Args:
         model: Name or path of the huggingface model to use.
             It is also used as the content for `model_name` tag in metrics
            output when `served_model_name` is not specified.
+        task: The task to use the model for. Each vLLM instance only supports
+            one task, even if the same model can be used for multiple tasks.
+            When the model only supports one task, "auto" can be used to select
+            it; otherwise, you must specify explicitly which task to use.
         tokenizer: Name or path of the huggingface tokenizer to use.
         tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
             available, "slow" will always use the slow tokenizer, and
@@ -108,6 +115,7 @@ class ModelConfig:
 
     def __init__(self,
                  model: str,
+                 task: TaskOption,
                  tokenizer: str,
                  tokenizer_mode: str,
                  trust_remote_code: bool,
@@ -207,7 +215,11 @@ class ModelConfig:
 
         self.override_neuron_config = override_neuron_config if is_neuron(
         ) else None
-        self._verify_embedding_mode()
+
+        supported_tasks, task = self._resolve_task(task, self.hf_config)
+        self.supported_tasks = supported_tasks
+        self.task: Final = task
+
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -241,18 +253,41 @@ class ModelConfig:
                 "either 'auto', 'slow' or 'mistral'.")
         self.tokenizer_mode = tokenizer_mode
 
-    def _verify_embedding_mode(self) -> None:
-        architectures = getattr(self.hf_config, "architectures", [])
-
-        # TODO: Allow the same model architecture to be specified as either
-        # generation or embedding model
-        if "Phi3VForCausalLM" in architectures:
-            # Match both remote and local names
-            embedding_mode = "/VLM2Vec" in self.model
+    def _resolve_task(
+        self,
+        task_option: TaskOption,
+        hf_config: PretrainedConfig,
+    ) -> Tuple[Set[Task], Task]:
+        architectures = getattr(hf_config, "architectures", [])
+
+        task_support: Dict[Task, bool] = {
+            # NOTE: Listed from highest to lowest priority,
+            # in case the model supports multiple of them
+            "generate": ModelRegistry.is_text_generation_model(architectures),
+            "embedding": ModelRegistry.is_embedding_model(architectures),
+        }
+        supported_tasks_lst: List[Task] = [
+            task for task, is_supported in task_support.items() if is_supported
+        ]
+        supported_tasks = set(supported_tasks_lst)
+
+        if task_option == "auto":
+            selected_task = next(iter(supported_tasks_lst))
+
+            if len(supported_tasks) > 1:
+                logger.info(
+                    "This model supports multiple tasks: %s. "
+                    "Defaulting to '%s'.", supported_tasks, selected_task)
         else:
-            embedding_mode = ModelRegistry.is_embedding_model(architectures)
+            if task_option not in supported_tasks:
+                msg = (
+                    f"This model does not support the '{task_option}' task. "
+                    f"Supported tasks: {supported_tasks}")
+                raise ValueError(msg)
 
-        self.embedding_mode = embedding_mode
+            selected_task = task_option
+
+        return supported_tasks, selected_task
 
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
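To make the resolution rule above easier to follow outside of vLLM, here is a standalone re-statement of the same logic using a hypothetical support table in place of the ``ModelRegistry`` lookups:

.. code-block:: python

    from typing import Dict, List, Set, Tuple

    def resolve_task(task_option: str,
                     task_support: Dict[str, bool]) -> Tuple[Set[str], str]:
        # Keep insertion order: earlier entries have higher priority.
        supported: List[str] = [t for t, ok in task_support.items() if ok]
        if task_option == "auto":
            return set(supported), supported[0]
        if task_option not in supported:
            raise ValueError(
                f"This model does not support the '{task_option}' task.")
        return set(supported), task_option

    # An architecture registered for both tasks defaults to "generate":
    print(resolve_task("auto", {"generate": True, "embedding": True}))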
@@ -401,7 +436,7 @@ class ModelConfig:

         # Async postprocessor is not necessary with embedding mode
         # since there is no token generation
-        if self.embedding_mode:
+        if self.task == "embedding":
             self.use_async_output_proc = False

         # Reminder: Please update docs/source/serving/compatibility_matrix.rst

@@ -582,11 +617,6 @@ class ModelConfig:
             (hasattr(self.hf_config, "text_config") and getattr(
                 self.hf_config.text_config, "is_encoder_decoder", False)))

-    @property
-    def is_embedding_model(self) -> bool:
-        """Extract the embedding model flag."""
-        return self.embedding_mode
-
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

@@ -943,6 +973,7 @@ class SchedulerConfig:
     """Scheduler configuration.

     Args:
+        task: The task to use the model for.
         max_num_batched_tokens: Maximum number of tokens to be processed in
             a single iteration.
         max_num_seqs: Maximum number of sequences to be processed in a single

@@ -957,7 +988,6 @@ class SchedulerConfig:
             prompt latency) before scheduling next prompt.
         enable_chunked_prefill: If True, prefill requests can be chunked based
             on the remaining max_num_batched_tokens.
-        embedding_mode: Whether the running model is for embedding.
         preemption_mode: Whether to perform preemption by swapping or
             recomputation. If not specified, we determine the mode as follows:
             We use recomputation by default since it incurs lower overhead than

@@ -972,13 +1002,13 @@ class SchedulerConfig:
     """

     def __init__(self,
+                 task: Task,
                  max_num_batched_tokens: Optional[int],
                  max_num_seqs: int,
                  max_model_len: int,
                  num_lookahead_slots: int = 0,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
-                 embedding_mode: bool = False,
                  is_multimodal_model: bool = False,
                  preemption_mode: Optional[str] = None,
                  num_scheduler_steps: int = 1,

@@ -1002,7 +1032,7 @@ class SchedulerConfig:
                 # for higher throughput.
                 max_num_batched_tokens = max(max_model_len, 2048)

-            if embedding_mode:
+            if task == "embedding":
                 # For embedding, choose specific value for higher throughput
                 max_num_batched_tokens = max(
                     max_num_batched_tokens,

@@ -1022,12 +1052,12 @@ class SchedulerConfig:
                 "Chunked prefill is enabled with max_num_batched_tokens=%d.",
                 self.max_num_batched_tokens)

+        self.task: Final = task
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
         self.num_lookahead_slots = num_lookahead_slots
         self.delay_factor = delay_factor
         self.chunked_prefill_enabled = enable_chunked_prefill
-        self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
         self.num_scheduler_steps = num_scheduler_steps
         self.multi_step_stream_outputs = multi_step_stream_outputs

@@ -1239,6 +1269,7 @@ class SpeculativeConfig:
             ngram_prompt_lookup_min = 0
             draft_model_config = ModelConfig(
                 model=speculative_model,
+                task=target_model_config.task,
                 tokenizer=target_model_config.tokenizer,
                 tokenizer_mode=target_model_config.tokenizer_mode,
                 trust_remote_code=target_model_config.trust_remote_code,
@@ -313,7 +313,7 @@ class Scheduler:
         self.lora_config = lora_config

         version = "selfattn"
-        if (self.scheduler_config.embedding_mode
+        if (self.scheduler_config.task == "embedding"
                 or self.cache_config.is_attention_free):
             version = "placeholder"

@@ -3,7 +3,7 @@ import dataclasses
 import json
 from dataclasses import dataclass
 from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
-                    Tuple, Type, Union, cast)
+                    Tuple, Type, Union, cast, get_args)

 import torch

@@ -12,7 +12,7 @@ from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                          DeviceConfig, EngineConfig, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, ObservabilityConfig,
                          ParallelConfig, PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig, TokenizerPoolConfig)
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

@@ -84,6 +84,7 @@ class EngineArgs:
     model: str = 'facebook/opt-125m'
     served_model_name: Optional[Union[str, List[str]]] = None
     tokenizer: Optional[str] = None
+    task: TaskOption = "auto"
     skip_tokenizer_init: bool = False
     tokenizer_mode: str = 'auto'
     trust_remote_code: bool = False

@@ -198,6 +199,15 @@ class EngineArgs:
             type=str,
             default=EngineArgs.model,
             help='Name or path of the huggingface model to use.')
+        parser.add_argument(
+            '--task',
+            default=EngineArgs.task,
+            choices=get_args(TaskOption),
+            help='The task to use the model for. Each vLLM instance only '
+            'supports one task, even if the same model can be used for '
+            'multiple tasks. When the model only supports one task, "auto" '
+            'can be used to select it; otherwise, you must specify explicitly '
+            'which task to use.')
         parser.add_argument(
             '--tokenizer',
             type=nullable_str,
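The new ``--task`` flag derives its ``choices`` from the ``TaskOption`` literal via ``typing.get_args``, so the CLI and the type definition cannot drift apart. A small stand-alone sketch of that pattern follows; the concrete ``TaskOption`` values are assumed from this diff, and the parser below is a toy one, not vLLM's.

    import argparse
    from typing import Literal, get_args

    # Assumed shape of the alias imported from vllm.config in this change.
    TaskOption = Literal["auto", "generate", "embedding"]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--task",
        default="auto",
        choices=get_args(TaskOption),  # ('auto', 'generate', 'embedding')
        help="The task to use the model for.")

    print(parser.parse_args(["--task", "embedding"]).task)  # -> embedding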
@@ -838,6 +848,7 @@ class EngineArgs:
     def create_model_config(self) -> ModelConfig:
         return ModelConfig(
             model=self.model,
+            task=self.task,
             # We know this is not None because we set it in __post_init__
             tokenizer=cast(str, self.tokenizer),
             tokenizer_mode=self.tokenizer_mode,

@@ -1026,13 +1037,13 @@ class EngineArgs:
                 " please file an issue with detailed information.")

         scheduler_config = SchedulerConfig(
+            task=model_config.task,
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
             num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
-            embedding_mode=model_config.embedding_mode,
             is_multimodal_model=model_config.is_multimodal_model,
             preemption_mode=self.preemption_mode,
             num_scheduler_steps=self.num_scheduler_steps,
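With the two hunks above, the task selected on ``EngineArgs`` flows into ``ModelConfig`` and from there into ``SchedulerConfig``, so every component reads the same ``task`` value instead of a separate ``embedding_mode`` flag. A hedged usage sketch of that plumbing is below; the model name is only an example, and building the configs fetches the HF model config, so treat it as illustrative rather than a test.

    from vllm import EngineArgs

    # Example embedding-capable model; substitute any model you actually use.
    engine_args = EngineArgs(model="intfloat/e5-mistral-7b-instruct",
                             task="embedding")
    engine_config = engine_args.create_engine_config()
    assert engine_config.model_config.task == "embedding"
    assert engine_config.scheduler_config.task == "embedding"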
@@ -344,7 +344,7 @@ class LLMEngine:
             observability_config=self.observability_config,
         )

-        if not self.model_config.embedding_mode:
+        if self.model_config.task != "embedding":
             self._initialize_kv_caches()

         # If usage stat is enabled, collect relevant info.

@@ -1116,7 +1116,7 @@ class LLMEngine:
                     seq_group.metrics.model_execute_time = (
                         o.model_execute_time)

-        if self.model_config.embedding_mode:
+        if self.model_config.task == "embedding":
             self._process_sequence_group_outputs(seq_group, output)
         else:
             self.output_processor.process_prompt_logprob(seq_group, output)

@@ -1855,9 +1855,6 @@ class LLMEngine:
     def is_encoder_decoder_model(self):
         return self.input_preprocessor.is_encoder_decoder_model()

-    def is_embedding_model(self):
-        return self.model_config.is_embedding_model
-
     def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
                                                    EncoderDecoderInputs]):
         if self.model_config.is_multimodal_model:

@@ -8,7 +8,7 @@ from tqdm import tqdm

 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
-from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.arg_utils import EngineArgs, TaskOption
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_hf_chat_template,

@@ -29,7 +29,7 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import Counter, deprecate_kwargs, is_list_of
+from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of

 logger = init_logger(__name__)

@@ -108,6 +108,12 @@ class LLM:
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""

+    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
+    """
+    A flag to toggle whether to deprecate positional arguments in
+    :meth:`LLM.__init__`.
+    """
+
     @classmethod
     @contextmanager
     def deprecate_legacy_api(cls):

@@ -117,6 +123,13 @@ class LLM:

         cls.DEPRECATE_LEGACY = False

+    @deprecate_args(
+        start_index=2,  # Ignore self and model
+        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
+        additional_message=(
+            "All positional arguments other than `model` will be "
+            "replaced with keyword arguments in an upcoming version."),
+    )
     def __init__(
         self,
         model: str,

@@ -139,6 +152,8 @@ class LLM:
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        # After positional args are removed, move this right below `model`
+        task: TaskOption = "auto",
         **kwargs,
     ) -> None:
         '''

@@ -153,6 +168,7 @@ class LLM:

         engine_args = EngineArgs(
             model=model,
+            task=task,
             tokenizer=tokenizer,
             tokenizer_mode=tokenizer_mode,
             skip_tokenizer_init=skip_tokenizer_init,
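The ``task`` keyword is now forwarded from ``LLM(...)`` to ``EngineArgs``, so offline users can pick the mode at construction time without touching the CLI. A brief, hedged sketch is shown below; the model name is hypothetical and only stands in for an architecture that is registered for both tasks.

    from vllm import LLM

    # Illustrative only: for an architecture registered for both tasks,
    # the task keyword decides whether the engine runs generation or embedding.
    llm = LLM(model="my-org/my-multi-task-model",  # hypothetical model name
              task="embedding")
    outputs = llm.encode("Represent this sentence for retrieval.")
    print(len(outputs[0].outputs.embedding))  # embedding dimension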
@@ -316,10 +332,21 @@ class LLM:
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
-        if self.llm_engine.model_config.embedding_mode:
-            raise ValueError(
-                "LLM.generate() is only supported for (conditional) generation "
-                "models (XForCausalLM, XForConditionalGeneration).")
+        task = self.llm_engine.model_config.task
+        if task != "generate":
+            messages = [
+                "LLM.generate() is only supported for (conditional) generation "
+                "models (XForCausalLM, XForConditionalGeneration).",
+            ]
+
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+            if "generate" in supported_tasks:
+                messages.append(
+                    "Your model supports the 'generate' task, but is "
+                    f"currently initialized for the '{task}' task. Please "
+                    "initialize the model using `--task generate`.")
+
+            raise ValueError(" ".join(messages))

         if prompt_token_ids is not None:
             parsed_prompts = self._convert_v1_inputs(

@@ -692,10 +719,18 @@ class LLM:
             considered legacy and may be deprecated in the future. You should
             instead pass them via the ``inputs`` parameter.
         """
-        if not self.llm_engine.model_config.embedding_mode:
-            raise ValueError(
-                "LLM.encode() is only supported for embedding models (XModel)."
-            )
+        task = self.llm_engine.model_config.task
+        if task != "embedding":
+            messages = ["LLM.encode() is only supported for embedding models."]
+
+            supported_tasks = self.llm_engine.model_config.supported_tasks
+            if "embedding" in supported_tasks:
+                messages.append(
+                    "Your model supports the 'embedding' task, but is "
+                    f"currently initialized for the '{task}' task. Please "
+                    "initialize the model using `--task embedding`.")
+
+            raise ValueError(" ".join(messages))

         if prompt_token_ids is not None:
             parsed_prompts = self._convert_v1_inputs(
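Both ``generate()`` and ``encode()`` now share the same guard shape: collect the base error message, append a hint when the model could support the requested task under a different ``--task`` setting, and raise a single ``ValueError``. Distilled into a stand-alone helper purely for illustration (the names below are not part of vLLM):

    from typing import Sequence

    def check_task(requested: str, current: str,
                   supported: Sequence[str]) -> None:
        """Raise a ValueError that explains how to re-initialize if possible."""
        if current == requested:
            return
        messages = [f"This entry point only supports the '{requested}' task."]
        if requested in supported:
            messages.append(
                f"Your model supports the '{requested}' task, but is currently "
                f"initialized for the '{current}' task. "
                f"Please re-initialize it with --task {requested}.")
        raise ValueError(" ".join(messages))

    check_task("embedding", "embedding", ["generate", "embedding"])  # passes
    try:
        check_task("embedding", "generate", ["generate", "embedding"])
    except ValueError as e:
        print(e)  # message includes the re-initialization hint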
@@ -905,6 +940,3 @@ class LLM:

     def _is_encoder_decoder_model(self):
         return self.llm_engine.is_encoder_decoder_model()
-
-    def _is_embedding_model(self):
-        return self.llm_engine.is_embedding_model()

@@ -83,7 +83,8 @@ class OpenAIServingEmbedding(OpenAIServing):
                          lora_modules=None,
                          prompt_adapters=None,
                          request_logger=request_logger)
-        self._enabled = self._check_embedding_mode(model_config.embedding_mode)
+        self._enabled = self._check_embedding_mode(
+            model_config.task == "embedding")

     async def create_embedding(
         self,

@@ -1034,10 +1034,54 @@ def identity(value: T) -> T:
 F = TypeVar('F', bound=Callable[..., Any])


+def deprecate_args(
+    start_index: int,
+    is_deprecated: Union[bool, Callable[[], bool]] = True,
+    additional_message: Optional[str] = None,
+) -> Callable[[F], F]:
+
+    if not callable(is_deprecated):
+        is_deprecated = partial(identity, is_deprecated)
+
+    def wrapper(fn: F) -> F:
+
+        params = inspect.signature(fn).parameters
+        pos_types = (
+            inspect.Parameter.POSITIONAL_ONLY,
+            inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        )
+        pos_kws = [
+            kw for kw, param in params.items() if param.kind in pos_types
+        ]
+
+        @wraps(fn)
+        def inner(*args, **kwargs):
+            if is_deprecated():
+                deprecated_args = pos_kws[start_index:len(args)]
+                if deprecated_args:
+                    msg = (
+                        f"The positional arguments {deprecated_args} are "
+                        "deprecated and will be removed in a future update.")
+                    if additional_message is not None:
+                        msg += f" {additional_message}"
+
+                    warnings.warn(
+                        DeprecationWarning(msg),
+                        stacklevel=3,  # The inner function takes up one level
+                    )
+
+            return fn(*args, **kwargs)
+
+        return inner  # type: ignore
+
+    return wrapper
+
+
 def deprecate_kwargs(
     *kws: str,
     is_deprecated: Union[bool, Callable[[], bool]] = True,
-    additional_message: Optional[str] = None) -> Callable[[F], F]:
+    additional_message: Optional[str] = None,
+) -> Callable[[F], F]:
     deprecated_kws = set(kws)

     if not callable(is_deprecated):
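``deprecate_args`` mirrors the existing ``deprecate_kwargs`` helper but targets positional arguments: it inspects the wrapped signature once, then warns whenever more positionals than ``start_index`` are actually passed. A hedged usage sketch against this patched ``vllm.utils`` follows; the class below is a toy written for illustration, not vLLM code.

    import warnings

    from vllm.utils import deprecate_args  # added by this change


    class Toy:

        @deprecate_args(start_index=2)  # ignore `self` and `model`
        def __init__(self, model: str, tokenizer: str = "auto") -> None:
            self.model = model
            self.tokenizer = tokenizer


    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        Toy("facebook/opt-125m", "some-tokenizer")   # positional: warns
        Toy("facebook/opt-125m", tokenizer="auto")   # keyword form: no warning

    print([str(w.message) for w in caught])
    # -> one DeprecationWarning naming ['tokenizer']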
@@ -92,7 +92,7 @@ class Worker(LocalOrDistributedWorkerBase):
         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
         if model_runner_cls is not None:
             ModelRunnerClass = model_runner_cls
-        elif self._is_embedding_model():
+        elif model_config.task == "embedding":
             ModelRunnerClass = EmbeddingModelRunner
         elif self._is_encoder_decoder_model():
             ModelRunnerClass = EncoderDecoderModelRunner

@@ -147,9 +147,6 @@ class Worker(LocalOrDistributedWorkerBase):
     def _is_encoder_decoder_model(self):
         return self.model_config.is_encoder_decoder_model

-    def _is_embedding_model(self):
-        return self.model_config.is_embedding_model
-
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until