[Core] Consolidate prompt arguments to LLM engines (#4328)

Co-authored-by: Roger Wang <ywang@roblox.com>
Authored by Cyrus Leung on 2024-05-29 04:29:31 +08:00, committed by GitHub
parent 290f4ada2b
commit 5ae5ed1e60
43 changed files with 1407 additions and 442 deletions

View File

@@ -63,9 +63,9 @@ steps:
   mirror_hardwares: [amd]
   commands:
-  # these tests have to be separated, because each one will allocate all posible GPU memory
-  - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
-  - pytest -v -s entrypoints/test_server_oot_registration.py
+  - pytest -v -s test_inputs.py
+  - pytest -v -s entrypoints -m llm
+  - pytest -v -s entrypoints -m openai

 - label: Examples Test
   working_dir: "/vllm-workspace/examples"

@@ -110,6 +110,9 @@ steps:
   mirror_hardwares: [amd]
   command: pytest -v -s test_logits_processor.py

+- label: Utils Test
+  command: pytest -v -s test_utils.py
+
 - label: Worker Test
   mirror_hardwares: [amd]
   command: pytest -v -s worker

View File

@@ -3,13 +3,14 @@ import argparse
 import json
 import time
 from pathlib import Path
-from typing import Optional
+from typing import List, Optional

 import numpy as np
 import torch
 from tqdm import tqdm

 from vllm import LLM, SamplingParams
+from vllm.inputs import PromptStrictInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

@@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
+    dummy_inputs: List[PromptStrictInputs] = [{
+        "prompt_token_ids": batch
+    } for batch in dummy_prompt_token_ids.tolist()]

     def run_to_completion(profile_dir: Optional[str] = None):
         if profile_dir:

@@ -59,13 +62,13 @@ def main(args: argparse.Namespace):
                 ],
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
                     str(profile_dir))) as p:
-                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                llm.generate(dummy_inputs,
                              sampling_params=sampling_params,
                              use_tqdm=False)
             print(p.key_averages())
         else:
             start_time = time.perf_counter()
-            llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+            llm.generate(dummy_inputs,
                          sampling_params=sampling_params,
                          use_tqdm=False)
             end_time = time.perf_counter()
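For context, the consolidated API accepts a plain string, a text dict, or a token-id dict through the same positional argument. A minimal sketch of the three equivalent call styles (the model name is only an example, not part of this diff):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model for illustration
params = SamplingParams(max_tokens=16)

# All three forms now go through the same `inputs` argument.
llm.generate("Hello, my name is", sampling_params=params)
llm.generate({"prompt": "Hello, my name is"}, sampling_params=params)
llm.generate({"prompt_token_ids": [1, 2, 3, 4]}, sampling_params=params)
```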

View File

@@ -1,5 +1,5 @@
 LLM Class
-==========
+=========

 .. autoclass:: vllm.LLM
     :members:

View File

@@ -0,0 +1,14 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptStrictInputs

.. autoclass:: vllm.inputs.TextPrompt
    :show-inheritance:
    :members:
    :member-order: bysource

.. autoclass:: vllm.inputs.TokensPrompt
    :show-inheritance:
    :members:
    :member-order: bysource
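A short usage sketch of the documented input types, assuming they accept keyword construction as TypedDicts (the model name is illustrative):

```python
from vllm import LLM, SamplingParams, TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # illustrative model
params = SamplingParams(temperature=0.0, max_tokens=8)

# Text and token prompts can be mixed in one batch, since both are
# valid PromptStrictInputs.
outputs = llm.generate(
    [
        TextPrompt(prompt="The capital of France is"),
        TokensPrompt(prompt_token_ids=[1, 2, 3]),
    ],
    sampling_params=params,
)
for o in outputs:
    print(o.outputs[0].text)
```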

View File

@@ -0,0 +1,8 @@
Offline Inference
=================================

.. toctree::
    :maxdepth: 1

    llm
    llm_inputs

View File

@@ -68,13 +68,6 @@ Documentation
    getting_started/quickstart
    getting_started/examples/examples_index

-.. toctree::
-   :maxdepth: 1
-   :caption: Offline Inference
-
-   offline_inference/llm
-   offline_inference/sampling_params
-
 .. toctree::
    :maxdepth: 1
    :caption: Serving

@@ -109,6 +102,8 @@ Documentation
    :maxdepth: 2
    :caption: Developer Documentation

+   dev/sampling_params
+   dev/offline_inference/offline_index
    dev/engine/engine_index
    dev/kernel/paged_attention
    dev/dockerfile/dockerfile

View File

@@ -48,7 +48,7 @@ completion = client.chat.completions.create(
 ```

 ### Extra Parameters for Chat API

-The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
+The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python

@@ -65,7 +65,7 @@ The following extra parameters are supported:
 ```

 ### Extra Parameters for Completions API

-The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
+The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

 ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
 :language: python

View File

@@ -23,11 +23,15 @@ def run_llava_pixel_values():
             "\nUSER: What is the content of this image?\nASSISTANT:")

     # This should be provided by another online or offline component.
-    images = torch.load("images/stop_sign_pixel_values.pt")
+    image = torch.load("images/stop_sign_pixel_values.pt")

-    outputs = llm.generate(prompt,
-                           multi_modal_data=MultiModalData(
-                               type=MultiModalData.Type.IMAGE, data=images))
+    outputs = llm.generate({
+        "prompt":
+        prompt,
+        "multi_modal_data":
+        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
+    })

     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)

@@ -46,11 +50,14 @@ def run_llava_image_features():
             "\nUSER: What is the content of this image?\nASSISTANT:")

     # This should be provided by another online or offline component.
-    images = torch.load("images/stop_sign_image_features.pt")
+    image = torch.load("images/stop_sign_image_features.pt")

-    outputs = llm.generate(prompt,
-                           multi_modal_data=MultiModalData(
-                               type=MultiModalData.Type.IMAGE, data=images))
+    outputs = llm.generate({
+        "prompt":
+        prompt,
+        "multi_modal_data":
+        MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
+    })

     for o in outputs:
         generated_text = o.outputs[0].text
         print(generated_text)
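Because `multi_modal_data` now travels inside each prompt dict, a batch can carry a different image per prompt. A minimal sketch building on the example above (`llm`, `prompt`, `torch`, and `MultiModalData` are assumed to be set up exactly as in the example; the second tensor file is hypothetical):

```python
image_a = torch.load("images/stop_sign_pixel_values.pt")
image_b = torch.load("images/another_image_pixel_values.pt")  # hypothetical file

# One call, one image per prompt.
outputs = llm.generate([
    {
        "prompt": prompt,
        "multi_modal_data": MultiModalData(type=MultiModalData.Type.IMAGE,
                                           data=image_a),
    },
    {
        "prompt": prompt,
        "multi_modal_data": MultiModalData(type=MultiModalData.Type.IMAGE,
                                           data=image_b),
    },
])
```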

View File

@@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"

 [tool.isort]
 use_parentheses = true
 skip_gitignore = true
+
+[tool.pytest.ini_options]
+markers = [
+    "skip_global_cleanup",
+    "llm: run tests for vLLM API only",
+    "openai: run tests for OpenAI API only",
+]

View File

@@ -25,7 +25,7 @@ class MockEngine:
         return [RequestOutput(
             request_id=self.request_id)] if self.request_id else []

-    async def encode_request_async(self, *args, **kwargs):
+    async def process_model_inputs_async(self, *args, **kwargs):
         pass

     def generate(self, request_id):

View File

@@ -29,7 +29,7 @@ def server():
     ray.shutdown()

-@pytest.fixture(scope="session")
+@pytest.fixture(scope="module")
 def client():
     client = openai.AsyncOpenAI(
         base_url="http://localhost:8000/v1",

View File

@@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
+from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.sequence import MultiModalData

@@ -402,12 +403,22 @@ class VllmRunner:
     ) -> List[Tuple[List[int], str]]:
         if images is not None:
             assert len(prompts) == images.shape[0]
-        req_outputs = self.model.generate(
-            prompts,
-            sampling_params=sampling_params,
-            multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
-                                            data=images)
-            if images is not None else None)
+
+        prompt_inputs: List[PromptInputs] = []
+        for i, prompt in enumerate(prompts):
+            image = None if images is None else images[i:i + 1]
+            mm_data = None if image is None else MultiModalData(
+                type=MultiModalData.Type.IMAGE,
+                data=image,
+            )
+
+            prompt_inputs.append({
+                "prompt": prompt,
+                "multi_modal_data": mm_data,
+            })
+
+        req_outputs = self.model.generate(prompt_inputs,
+                                          sampling_params=sampling_params)

         outputs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt

View File

@@ -133,8 +133,11 @@ def test_append_slot_cow():
     # Allocate prompt to gpu block. There is one slot left in the block.
     prompt = Sequence(seq_id=1,
-                      prompt="one two three",
-                      prompt_token_ids=[1, 2, 3],
+                      inputs={
+                          "prompt": "one two three",
+                          "prompt_token_ids": [1, 2, 3],
+                          "multi_modal_data": None
+                      },
                       block_size=block_size)

     # Fork the sequence, such that a COW will be required when we append a new

@@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():
     assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks

-    parent = Sequence(1, "one two three", [0, 1, 2], block_size)
+    parent = Sequence(seq_id=1,
+                      inputs={
+                          "prompt": "one two three",
+                          "prompt_token_ids": [0, 1, 2],
+                          "multi_modal_data": None
+                      },
+                      block_size=block_size)
     seq_group = SequenceGroup(request_id="1",
                               seqs=[parent],
                               arrival_time=time.time(),

View File

@@ -21,7 +21,13 @@ def create_dummy_prompt(
     # and prompt "0 ... block_size".
     prompt_tokens = list(range(prompt_length))
     prompt_str = " ".join([str(t) for t in prompt_tokens])
-    prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
+    prompt = Sequence(int(request_id),
+                      inputs={
+                          "prompt": prompt_str,
+                          "prompt_token_ids": prompt_tokens,
+                          "multi_modal_data": None,
+                      },
+                      block_size=block_size)
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[prompt],
                               arrival_time=time.time(),

@@ -51,8 +57,11 @@ def create_seq_group(
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         seq = Sequence(
             seq_id=seq_id_start + seq_id_offset,
-            prompt="",
-            prompt_token_ids=prompt_token_ids,
+            inputs={
+                "prompt": "",
+                "prompt_token_ids": prompt_token_ids,
+                "multi_modal_data": None,
+            },
             block_size=16,
         )

View File

@@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
     with pytest.raises(ValueError) as err:
         llm.generate("abc", sampling_params)
     assert "prompts must be None if" in str(err.value)
-    outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
+    outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
                            sampling_params=sampling_params)
     assert len(outputs) > 0
     completions = outputs[0].outputs

View File

@@ -1,11 +1,15 @@
 import asyncio
 from dataclasses import dataclass

+import pytest
+
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat

 MODEL_NAME = "openai-community/gpt2"
 CHAT_TEMPLATE = "Dummy chat template for testing {}"

+pytestmark = pytest.mark.openai
+

 @dataclass
 class MockModelConfig:

View File

@@ -52,6 +52,8 @@ TEST_SCHEMA = {
 TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
               r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")

+pytestmark = pytest.mark.openai
+

 def test_guided_logits_processors():
     """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""

View File

@@ -0,0 +1,144 @@
import weakref
from typing import List
import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from ..conftest import cleanup
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
TOKEN_IDS = [
# Using ID={0, 1, 2, 3} results in NaN values,
# so we add this offset of 1000
[1000],
[1000, 1001],
[1000, 1002, 1001],
[1000, 1003, 1001, 1002],
]
pytestmark = pytest.mark.llm
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=32768,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
o2: List[EmbeddingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=prompt_token_ids,
pooling_params=pooling_params)
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=TOKEN_IDS,
pooling_params=pooling_params)
v2_output = llm.encode(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_multiple_pooling_params(llm: LLM):
pooling_params = [
PoolingParams(),
PoolingParams(),
PoolingParams(),
PoolingParams(),
]
# Multiple PoolingParams should be matched with each prompt
outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
assert len(PROMPTS) == len(outputs)
# Exception raised, if the size of params does not match the size of prompts
with pytest.raises(ValueError):
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
# Single PoolingParams should be applied to every prompt
single_pooling_params = PoolingParams()
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
assert len(PROMPTS) == len(outputs)
# pooling_params is None, default params should be applied
outputs = llm.encode(PROMPTS, pooling_params=None)
assert len(PROMPTS) == len(outputs)

View File

@@ -1,21 +1,124 @@
+import weakref
+from typing import List
+
 import pytest

-from vllm import LLM, SamplingParams
+from vllm import LLM, RequestOutput, SamplingParams
+
+from ..conftest import cleanup

-def test_multiple_sampling_params():
-    llm = LLM(model="facebook/opt-125m",
-              max_num_batched_tokens=4096,
-              tensor_parallel_size=1)
-    prompts = [
-        "Hello, my name is",
-        "The president of the United States is",
-        "The capital of France is",
-        "The future of AI is",
-    ]
+MODEL_NAME = "facebook/opt-125m"
+
+PROMPTS = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+]
TOKEN_IDS = [
[0],
[0, 1],
[0, 2, 1],
[0, 3, 1, 2],
]
pytestmark = pytest.mark.llm
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)
v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)
v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
sampling_params=sampling_params)
v2_output = llm.generate(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
     sampling_params = [
         SamplingParams(temperature=0.01, top_p=0.95),
         SamplingParams(temperature=0.3, top_p=0.95),

@@ -24,18 +127,18 @@ def test_multiple_sampling_params():
     ]

     # Multiple SamplingParams should be matched with each prompt
-    outputs = llm.generate(prompts, sampling_params=sampling_params)
-    assert len(prompts) == len(outputs)
+    outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
+    assert len(PROMPTS) == len(outputs)

     # Exception raised, if the size of params does not match the size of prompts
     with pytest.raises(ValueError):
-        outputs = llm.generate(prompts, sampling_params=sampling_params[:3])
+        outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3])

     # Single SamplingParams should be applied to every prompt
     single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
-    outputs = llm.generate(prompts, sampling_params=single_sampling_params)
-    assert len(prompts) == len(outputs)
+    outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
+    assert len(PROMPTS) == len(outputs)

     # sampling_params is None, default params should be applied
-    outputs = llm.generate(prompts, sampling_params=None)
-    assert len(prompts) == len(outputs)
+    outputs = llm.generate(PROMPTS, sampling_params=None)
+    assert len(PROMPTS) == len(outputs)
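The comparison above boils down to two call styles. A minimal sketch of both against a plain `LLM` (the model name is illustrative; whether the legacy form warns depends on the deprecation switch that the `llm` fixture enables via `deprecate_legacy_api`):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model
params = SamplingParams(temperature=0.0, max_tokens=8)

# Legacy (v1) style: separate keyword arguments, now on a deprecation path.
legacy = llm.generate(prompts="Hello, my name is", sampling_params=params)

# New (v2) style: a single positional inputs argument, string or dict.
new = llm.generate({"prompt": "Hello, my name is"}, sampling_params=params)
```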

View File

@@ -71,7 +71,7 @@ TEST_CHOICE = [
     "Swift", "Kotlin"
 ]

-pytestmark = pytest.mark.asyncio
+pytestmark = pytest.mark.openai


 @pytest.fixture(scope="session")

@@ -91,6 +91,8 @@ def server(zephyr_lora_files):
         "--max-model-len",
         "8192",
         "--enforce-eager",
+        "--gpu-memory-utilization",
+        "0.75",
         # lora config below
         "--enable-lora",
         "--lora-modules",

@@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files):
         # use half precision for speed and memory savings in CI environment
         "--dtype",
         "bfloat16",
+        "--enforce-eager",
+        "--gpu-memory-utilization",
+        "0.75",
         "--max-model-len",
         "8192",
-        "--enforce-eager",
     ])
     ray.get(server_runner.ready.remote())
     yield server_runner
@@ -136,6 +140,7 @@ def client():
     yield client


+@pytest.mark.asyncio
 async def test_check_models(server, client: openai.AsyncOpenAI):
     models = await client.models.list()
     models = models.data

@@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
     assert lora_models[1].id == "zephyr-lora2"


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
     "model_name",

@@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         completion.choices[0].text) >= 5


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
     "model_name",

@@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     assert choice.logprobs.top_logprobs is None


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",

@@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0


+@pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
                                  model_name: str):

@@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
     assert message.content is not None and len(message.content) >= 0


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",

@@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
     assert "".join(chunks) == single_output


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",

@@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
     assert "".join(chunks) == output


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # just test 1 lora hereafter
     "model_name",

@@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
     assert texts[0] == texts[1]


+@pytest.mark.asyncio
 async def test_logits_bias(server, client: openai.AsyncOpenAI):
     prompt = "Hello, my name is"
     max_tokens = 5

@@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_json_completion(server, client: openai.AsyncOpenAI,

@@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
     jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_json_chat(server, client: openai.AsyncOpenAI,

@@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
     assert json1["age"] != json2["age"]


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,

@@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
     assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,

@@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
     assert ip1 != ip2


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,

@@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
     assert completion.choices[i].text in TEST_CHOICE


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,

@@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
     assert choice1 != choice2


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,

@@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
         extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))


+@pytest.mark.asyncio
 @pytest.mark.parametrize("guided_decoding_backend",
                          ["outlines", "lm-format-enforcer"])
 async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,

@@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
                    for token, logprob in token_dict.items())


+@pytest.mark.asyncio
 async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
     for _ in range(2):
         resp = await client.chat.completions.create(

@@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
         assert loaded == {"result": 2}, loaded


+@pytest.mark.asyncio
 async def test_extra_fields(server, client: openai.AsyncOpenAI):
     with pytest.raises(BadRequestError) as exc_info:
         await client.chat.completions.create(

@@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
     assert "extra_forbidden" in exc_info.value.message


+@pytest.mark.asyncio
 async def test_complex_message_content(server, client: openai.AsyncOpenAI):
     resp = await client.chat.completions.create(
         model=MODEL_NAME,

@@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
     assert content == "2"


+@pytest.mark.asyncio
 async def test_custom_role(server, client: openai.AsyncOpenAI):
     # Not sure how the model handles custom roles so we just check that
     # both string and complex message content are handled in the same way

@@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI):
     assert content1 == content2


+@pytest.mark.asyncio
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement

@@ -847,6 +873,7 @@ number: "1" | "2"
     assert content.strip() == ground_truth


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
     "model_name",

@@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
     assert len(logprobs.tokens) > 5


+@pytest.mark.asyncio
 async def test_long_seed(server, client: openai.AsyncOpenAI):
     for seed in [
             torch.iinfo(torch.long).min - 1,

@@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
             or "less_than_equal" in exc_info.value.message)


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
     [EMBEDDING_MODEL_NAME],

@@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
     assert embeddings.usage.total_tokens == 5


+@pytest.mark.asyncio
 @pytest.mark.parametrize(
     "model_name",
     [EMBEDDING_MODEL_NAME],

View File

@@ -1,7 +1,7 @@
-import multiprocessing
 import sys
 import time

+import pytest
 import torch
 from openai import OpenAI, OpenAIError

@@ -10,6 +10,8 @@ from vllm.model_executor.models.opt import OPTForCausalLM
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.utils import get_open_port

+pytestmark = pytest.mark.openai
+

 class MyOPTForCausalLM(OPTForCausalLM):

@@ -26,15 +28,16 @@ def server_function(port):
     # register our dummy model
     ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
     sys.argv = ["placeholder.py"] + \
-        ("--model facebook/opt-125m --dtype"
-         f" float32 --api-key token-abc123 --port {port}").split()
+        ("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
+         f"--dtype float32 --api-key token-abc123 --port {port}").split()
     import runpy
     runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')


 def test_oot_registration_for_api_server():
     port = get_open_port()
-    server = multiprocessing.Process(target=server_function, args=(port, ))
+    ctx = torch.multiprocessing.get_context()
+    server = ctx.Process(target=server_function, args=(port, ))
     server.start()
     client = OpenAI(
         base_url=f"http://localhost:{port}/v1",

View File

@@ -86,20 +86,18 @@ def generate(

 def batched_generate(
-    llm,
+    llm: vllm.LLM,
     inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
 ):
     for input in inputs:
         prompt, sampling_param, lora_req = input
-        requests_data = llm._validate_and_prepare_requests(
+        # Add requests to the engine and run the engine
+        llm._validate_and_add_requests(
             prompt,
             sampling_param,
             lora_request=lora_req,
         )
-        # Add requests to the engine and run the engine
-        for request_data in requests_data:
-            llm._add_request(**request_data)

     outputs = llm._run_engine(use_tqdm=True)
     return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]

View File

@@ -35,28 +35,25 @@ def test_logits_processor_force_generate(
     # test logits_processors when prompt_logprobs is not None
     vllm_model.model._add_request(
-        prompt=example_prompts[0],
+        example_prompts[0],
         params=params_with_logprobs,
-        prompt_token_ids=None,
     )

     # test prompt_logprobs is not None
     vllm_model.model._add_request(
-        prompt=example_prompts[1],
+        example_prompts[1],
         params=SamplingParams(
             prompt_logprobs=3,
             max_tokens=max_tokens,
         ),
-        prompt_token_ids=None,
     )

     # test grouped requests
     vllm_model.model._add_request(
-        prompt=example_prompts[2],
+        example_prompts[2],
         params=SamplingParams(max_tokens=max_tokens),
-        prompt_token_ids=None,
     )

-    outputs = vllm_model.model._run_engine(False)
+    outputs = vllm_model.model._run_engine(use_tqdm=False)

     assert outputs[0].outputs[0].text == enforced_answers * repeat_times

View File

@@ -57,11 +57,7 @@ def test_random_sample_with_seed(
             sampling_params_seed_1,
             sampling_params_seed_2,
     ):
-        llm._add_request(
-            prompt=prompt,
-            prompt_token_ids=None,
-            params=params,
-        )
+        llm._add_request(prompt, params=params)

     results = llm._run_engine(use_tqdm=False)

     all_outputs = [[out.token_ids for out in output.outputs]

View File

@@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
     for prompt in prompts:
         hashes[-1].append([])
         prompt_token_ids = tokenizer.encode(prompt)
-        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
-                       tokenizer.tokenizer.eos_token_id, lora_request)
+        seq = Sequence(seq_id,
+                       inputs={
+                           "prompt": prompt,
+                           "prompt_token_ids": prompt_token_ids,
+                           "multi_modal_data": None,
+                       },
+                       block_size=block_size,
+                       eos_token_id=tokenizer.tokenizer.eos_token_id,
+                       lora_request=lora_request)

         num_blocks = len(prompt_token_ids) // block_size
         for idx in range(num_blocks):

tests/test_inputs.py (new file, 53 additions)
View File

@@ -0,0 +1,53 @@
from typing import List
import pytest
from vllm.inputs import parse_and_batch_prompt
STRING_INPUTS = [
'',
'foo',
'foo bar',
'foo baz bar',
'foo bar qux baz',
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
def test_parse_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([[]])
@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
assert parse_and_batch_prompt(string_input) \
== parse_and_batch_prompt([string_input])
@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]):
assert parse_and_batch_prompt(token_input) \
== parse_and_batch_prompt([token_input])
@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
== parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
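Stated directly, the property these tests check is that a single prompt and a one-element batch parse to the same result, whatever the concrete element type is. A sketch using only the behavior exercised above:

```python
from vllm.inputs import parse_and_batch_prompt

# A bare string and a singleton list of strings parse identically ...
assert parse_and_batch_prompt("foo bar") == parse_and_batch_prompt(["foo bar"])

# ... and the same holds for token-id prompts.
assert parse_and_batch_prompt([1, 2, 3]) == parse_and_batch_prompt([[1, 2, 3]])
```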

tests/test_utils.py (new file, 63 additions)
View File

@@ -0,0 +1,63 @@
import pytest
from vllm.utils import deprecate_kwargs
from .utils import error_on_warning
def test_deprecate_kwargs_always():
@deprecate_kwargs("old_arg", is_deprecated=True)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_never():
@deprecate_kwargs("old_arg", is_deprecated=False)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_dynamic():
is_deprecated = True
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
is_deprecated = False
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_additional_message():
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="abcd"):
dummy(old_arg=1)
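As a usage sketch, the decorator exercised above could guard a legacy keyword in library code roughly as follows. The wrapped function below is hypothetical; only the decorator name and its `is_deprecated`/`additional_message` arguments come from these tests:

```python
from vllm.utils import deprecate_kwargs


@deprecate_kwargs("prompts",
                  is_deprecated=True,
                  additional_message="Please pass the prompt positionally.")
def legacy_generate(inputs=None, *, prompts=None):
    # Hypothetical shim: fall back to the legacy keyword if it was used.
    if prompts is not None:
        inputs = prompts
    return inputs


legacy_generate("Hello")           # no warning
legacy_generate(prompts="Hello")   # emits DeprecationWarning mentioning 'prompts'
```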

View File

@@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None):
     prompt_token_ids = prompt_token_ids or [1]
     return Sequence(
         seq_id=0,
-        prompt="<s>",
-        prompt_token_ids=prompt_token_ids,
+        inputs={
+            "prompt": "<s>",
+            "prompt_token_ids": prompt_token_ids,
+            "multi_modal_data": None,
+        },
         block_size=16,
     )

View File

@@ -2,6 +2,8 @@ import os
 import subprocess
 import sys
 import time
+import warnings
+from contextlib import contextmanager

 import ray
 import requests

@@ -87,3 +89,15 @@ def multi_process_tensor_parallel(
     ray.get(refs)

     ray.shutdown()
+
+
+@contextmanager
+def error_on_warning():
+    """
+    Within the scope of this context manager, tests will fail if any warning
+    is emitted.
+    """
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+
+        yield

View File

@@ -5,6 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
+from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import (CompletionOutput, EmbeddingOutput,
                           EmbeddingRequestOutput, RequestOutput)

@@ -16,6 +17,9 @@ __version__ = "0.4.2"
 __all__ = [
     "LLM",
     "ModelRegistry",
+    "PromptStrictInputs",
+    "TextPrompt",
+    "TokensPrompt",
     "SamplingParams",
     "RequestOutput",
     "CompletionOutput",

View File

@@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.executor.ray_utils import initialize_ray_cluster, ray
+from vllm.inputs import LLMInputs, PromptInputs
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.usage.usage_lib import UsageContext

 logger = init_logger(__name__)

@@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine):
         return request_outputs

-    async def encode_request_async(
+    async def process_model_inputs_async(
         self,
-        request_id: str,  # pylint: disable=unused-argument
-        prompt: Optional[str],
-        prompt_token_ids: Optional[List[int]] = None,
+        request_id: str,
+        inputs: PromptInputs,
         lora_request: Optional[LoRARequest] = None,
-    ):
-        if prompt_token_ids is None:
-            assert prompt is not None
-            prompt_token_ids = await self.tokenizer.encode_async(
-                request_id=request_id,
-                prompt=prompt,
-                lora_request=lora_request)
-        return prompt_token_ids
+    ) -> LLMInputs:
+        if isinstance(inputs, str):
+            inputs = {"prompt": inputs}
+
+        if "prompt_token_ids" not in inputs:
+            tokenizer = self.get_tokenizer_group("prompts must be None if "
+                                                 "skip_tokenizer_init is True")
+
+            prompt_token_ids = await tokenizer.encode_async(
+                request_id=request_id,
+                prompt=inputs["prompt"],
+                lora_request=lora_request)
+        else:
+            prompt_token_ids = inputs["prompt_token_ids"]
+
+        return LLMInputs(prompt_token_ids=prompt_token_ids,
+                         prompt=inputs.get("prompt"),
+                         multi_modal_data=inputs.get("multi_modal_data"))

     async def add_request_async(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
         if arrival_time is None:
             arrival_time = time.time()
-        prompt_token_ids = await self.encode_request_async(
-            request_id=request_id,
-            prompt=prompt,
-            prompt_token_ids=prompt_token_ids,
-            lora_request=lora_request)

-        return self.add_request(request_id,
-                                prompt=prompt,
-                                params=params,
-                                prompt_token_ids=prompt_token_ids,
-                                arrival_time=arrival_time,
-                                lora_request=lora_request,
-                                multi_modal_data=multi_modal_data)
+        processed_inputs = await self.process_model_inputs_async(
+            request_id=request_id, inputs=inputs, lora_request=lora_request)
+
+        self._add_processed_request(
+            request_id=request_id,
+            processed_inputs=processed_inputs,
+            params=params,
+            arrival_time=arrival_time,
+            lora_request=lora_request,
+        )

     async def check_health_async(self) -> None:
         self.model_executor.check_health()


 class AsyncLLMEngine:
-    """An asynchronous wrapper for LLMEngine.
+    """An asynchronous wrapper for :class:`LLMEngine`.

-    This class is used to wrap the LLMEngine class to make it asynchronous. It
-    uses asyncio to create a background loop that keeps processing incoming
-    requests. The LLMEngine is kicked by the generate method when there
-    are requests in the waiting queue. The generate method yields the outputs
-    from the LLMEngine to the caller.
+    This class is used to wrap the :class:`LLMEngine` class to make it
+    asynchronous. It uses asyncio to create a background loop that keeps
+    processing incoming requests. The :class:`LLMEngine` is kicked by the
+    generate method when there are requests in the waiting queue. The generate
+    method yields the outputs from the :class:`LLMEngine` to the caller.

-    NOTE: For the comprehensive list of arguments, see `LLMEngine`.
+    NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.

     Args:
         worker_use_ray: Whether to use Ray for model workers. Required for

@@ -315,8 +321,8 @@ class AsyncLLMEngine:
             being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
-        *args: Arguments for LLMEngine.
-        *kwargs: Arguments for LLMEngine.
+        *args: Arguments for :class:`LLMEngine`.
+        **kwargs: Arguments for :class:`LLMEngine`.
     """

     _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine

@@ -526,22 +532,26 @@ class AsyncLLMEngine:
     async def add_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncStream:
         if self.log_requests:
-            shortened_prompt = prompt
-            shortened_token_ids = prompt_token_ids
-            if self.max_log_len is not None:
+            if isinstance(inputs, str):
+                shortened_prompt = inputs
+                shortened_token_ids = None
+            else:
+                shortened_prompt = inputs.get("prompt")
+                shortened_token_ids = inputs.get("prompt_token_ids")
+
+            max_log_len = self.max_log_len
+            if max_log_len is not None:
                 if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:self.max_log_len]
+                    shortened_prompt = shortened_prompt[:max_log_len]
                 if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:self.
-                                                              max_log_len]
+                    shortened_token_ids = shortened_token_ids[:max_log_len]
+
             logger.info(
                 "Received request %s: prompt: %r, "
                 "params: %s, prompt_token_ids: %s, "

@@ -562,39 +572,33 @@ class AsyncLLMEngine:
             arrival_time = time.time()

         if self.engine_use_ray:
-            prompt_token_ids = await (
-                self.engine.encode_request_async.remote(  # type: ignore
-                    request_id=request_id,
-                    prompt=prompt,
-                    prompt_token_ids=prompt_token_ids,
-                    lora_request=lora_request))
+            processed_inputs = await self.engine.process_model_inputs_async \
+                .remote(  # type: ignore
+                    request_id=request_id,
+                    inputs=inputs,
+                    lora_request=lora_request)
         else:
-            prompt_token_ids = await self.engine.encode_request_async(
+            processed_inputs = await self.engine.process_model_inputs_async(
                 request_id=request_id,
-                prompt=prompt,
-                prompt_token_ids=prompt_token_ids,
+                inputs=inputs,
                 lora_request=lora_request)

         stream = self._request_tracker.add_request(
             request_id,
-            prompt=prompt,
+            inputs=processed_inputs,
             params=params,
-            prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
-            multi_modal_data=multi_modal_data,
         )

         return stream

     async def generate(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         sampling_params: SamplingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.

@@ -603,14 +607,12 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt string. Can be None if prompt_token_ids is
-                provided.
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
             sampling_params: The sampling parameters of the request.
             request_id: The unique id of the request.
-            prompt_token_ids: The token IDs of the prompt. If None, we
-                use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi modal data per request.

         Yields:
             The output `RequestOutput` objects from the LLMEngine

@@ -659,24 +661,20 @@ class AsyncLLMEngine:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self.process_request(
+        async for output in self._process_request(
                 request_id,
-                prompt,
+                inputs,
                 sampling_params,
-                prompt_token_ids,
-                lora_request,
-                multi_modal_data,
+                lora_request=lora_request,
         ):
-            yield output
+            yield LLMEngine.validate_output(output, RequestOutput)

     async def encode(
         self,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         pooling_params: PoolingParams,
         request_id: str,
-        prompt_token_ids: Optional[List[int]] = None,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.

@@ -685,14 +683,12 @@ class AsyncLLMEngine:
         from the LLMEngine to the caller.

         Args:
-            prompt: The prompt string. Can be None if prompt_token_ids is
-                provided.
+            inputs: The inputs to the LLM. See
+                :class:`~vllm.inputs.PromptInputs`
+                for more details about the format of each input.
             pooling_params: The pooling parameters of the request.
             request_id: The unique id of the request.
-            prompt_token_ids: The token IDs of the prompt. If None, we
-                use the tokenizer to convert the prompts to token IDs.
             lora_request: LoRA request to use for generation, if any.
-            multi_modal_data: Multi modal data per request.

         Yields:
             The output `EmbeddingRequestOutput` objects from the LLMEngine

@@ -739,24 +735,21 @@ class AsyncLLMEngine:
             >>> # Process and return the final output
             >>> ...
         """
-        async for output in self.process_request(
+        async for output in self._process_request(
                 request_id,
-                prompt,
+                inputs,
                 pooling_params,
-                prompt_token_ids,
-                lora_request,
-                multi_modal_data,
+                lora_request=lora_request,
         ):
-            yield output
+            yield LLMEngine.validate_output(output, EmbeddingRequestOutput)

-    async def process_request(
+    async def _process_request(
         self,
         request_id: str,
-        prompt: Optional[str],
+        inputs: PromptInputs,
         params: Union[SamplingParams, PoolingParams],
-        prompt_token_ids: Optional[List[int]] = None,
+        *,
         lora_request: Optional[LoRARequest] = None,
-        multi_modal_data: Optional[MultiModalData] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or
         PoolingParams."""

@@ -764,12 +757,10 @@ class AsyncLLMEngine:
         stream = await self.add_request(
             request_id,
-            prompt,
+            inputs,
             params,
-            prompt_token_ids=prompt_token_ids,
             arrival_time=arrival_time,
             lora_request=lora_request,
-            multi_modal_data=multi_modal_data,
         )

         try:

@ -1,5 +1,8 @@
import time import time
from typing import Iterable, List, Optional, Type, Union from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import GenerationConfig, PreTrainedTokenizer
@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
MultiModalData, PoolerOutput, SamplerOutput, PoolerOutput, SamplerOutput, Sequence,
Sequence, SequenceGroup, SequenceGroupMetadata, SequenceGroup, SequenceGroupMetadata,
SequenceStatus) SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
return {} return {}
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
class LLMEngine: class LLMEngine:
"""An LLM engine that receives requests and generates texts. """An LLM engine that receives requests and generates texts.
@ -60,11 +67,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the iteration-level scheduling and efficient memory management to maximize the
serving throughput. serving throughput.
The `LLM` class wraps this class for offline batched inference and the The :class:`~vllm.LLM` class wraps this class for offline batched inference
`AsyncLLMEngine` class wraps this class for online serving. and the :class:`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
comprehensive list of arguments, see `EngineArgs`. class. For the comprehensive list of arguments, see :ref:`engine_args`.
Args: Args:
model_config: The configuration related to the LLM model. model_config: The configuration related to the LLM model.
@ -81,9 +88,60 @@ class LLMEngine:
executor_class: The model executor class for managing distributed executor_class: The model executor class for managing distributed
execution. execution.
log_stats: Whether to log statistics. log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection usage_context: Specified entry point, used for usage info collection.
""" """
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
"""A flag to toggle whether to validate the type of request output."""
@classmethod
@contextmanager
def enable_output_validation(cls):
cls.DO_VALIDATE_OUTPUT = True
yield
cls.DO_VALIDATE_OUTPUT = False
@classmethod
def validate_output(
cls,
output: object,
output_type: Type[_O],
) -> _O:
do_validate = cls.DO_VALIDATE_OUTPUT
if ((TYPE_CHECKING or do_validate)
and not isinstance(output, output_type)):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
@classmethod
def validate_outputs(
cls,
outputs: GenericSequence[object],
output_type: Type[_O],
) -> List[_O]:
do_validate = cls.DO_VALIDATE_OUTPUT
outputs_: List[_O]
if TYPE_CHECKING or do_validate:
outputs_ = []
for output in outputs:
if not isinstance(output, output_type):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
outputs_.append(output)
else:
outputs_ = outputs
return outputs_
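A sketch of how the validation helpers behave: validate_output is a pass-through cast unless DO_VALIDATE_OUTPUT is toggled via the context manager, in which case a mismatched type raises:

from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput

with LLMEngine.enable_output_validation():
    try:
        # A plain string is not a RequestOutput, so this raises TypeError.
        LLMEngine.validate_output("not a request output", RequestOutput)
    except TypeError as exc:
        print(exc)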
tokenizer: Optional[BaseTokenizerGroup]
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
@ -151,12 +209,11 @@ class LLMEngine:
self.log_stats = log_stats self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init: if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup self.tokenizer = self._init_tokenizer()
self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer) self.detokenizer = Detokenizer(self.tokenizer)
else: else:
self.detokenizer = None
self.tokenizer = None self.tokenizer = None
self.detokenizer = None
self.seq_counter = Counter() self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict( self.generation_config_fields = _load_generation_config_dict(
@ -318,14 +375,26 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None): if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown() model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer": def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None) return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self, def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer": sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request) return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs): def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict( init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer, tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config), enable_lora=bool(self.lora_config),
@ -335,8 +404,9 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code, trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision) revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs) init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs) return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _verify_args(self) -> None: def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config) self.model_config.verify_with_parallel_config(self.parallel_config)
@ -346,29 +416,85 @@ class LLMEngine:
self.lora_config.verify_with_scheduler_config( self.lora_config.verify_with_scheduler_config(
self.scheduler_config) self.scheduler_config)
def encode_request( def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _add_processed_request(
self, self,
request_id: str, # pylint: disable=unused-argument request_id: str,
prompt: Optional[str], processed_inputs: LLMInputs,
prompt_token_ids: Optional[List[int]] = None, params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def process_model_inputs(
self,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
): ) -> LLMInputs:
if prompt_token_ids is None: if isinstance(inputs, str):
assert prompt is not None inputs = {"prompt": inputs}
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt, if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request) lora_request=lora_request)
return prompt_token_ids else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
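The normalization performed here can be illustrated without an engine; the sketch below mirrors process_model_inputs with a stand-in tokenizer callable:

def normalize_inputs(inputs, tokenize):
    # A bare string is promoted to a TextPrompt-style dict.
    if isinstance(inputs, str):
        inputs = {"prompt": inputs}
    # Tokenize only when token IDs were not supplied.
    if "prompt_token_ids" not in inputs:
        prompt_token_ids = tokenize(inputs["prompt"])
    else:
        prompt_token_ids = inputs["prompt_token_ids"]
    return {
        "prompt_token_ids": prompt_token_ids,
        "prompt": inputs.get("prompt"),
        "multi_modal_data": inputs.get("multi_modal_data"),
    }


print(normalize_inputs("hi", tokenize=lambda s: [1, 2]))
print(normalize_inputs({"prompt_token_ids": [3, 4]}, tokenize=None))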
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
prompt: Optional[str], inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
@ -378,15 +504,14 @@ class LLMEngine:
Args: Args:
request_id: The unique ID of the request. request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is inputs: The inputs to the LLM. See
provided. :class:`~vllm.inputs.PromptInputs`
params: Parameters for sampling or pooling. SamplingParams for more details about the format of each input.
for text generation. PoolingParams for pooling. params: Parameters for sampling or pooling.
prompt_token_ids: The token IDs of the prompt. If None, we :class:`~vllm.SamplingParams` for text generation.
use the tokenizer to convert the prompts to token IDs. :class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use arrival_time: The arrival time of the request. If None, we use
the current monotonic time. the current monotonic time.
multi_modal_data: Multi modal data per request.
Details: Details:
- Set arrival_time to the current time if it is None. - Set arrival_time to the current time if it is None.
@ -417,59 +542,26 @@ class LLMEngine:
"not enabled!") "not enabled!")
if arrival_time is None: if arrival_time is None:
arrival_time = time.time() arrival_time = time.time()
prompt_token_ids = self.encode_request(
request_id=request_id, processed_inputs = self.process_model_inputs(request_id=request_id,
prompt=prompt, inputs=inputs,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request) lora_request=lora_request)
# Create the sequences. self._add_processed_request(
block_size = self.cache_config.block_size request_id=request_id,
seq_id = next(self.seq_counter) processed_inputs=processed_inputs,
eos_token_id = None params=params,
if self.tokenizer: arrival_time=arrival_time,
eos_token_id = self.tokenizer.get_lora_tokenizer( lora_request=lora_request,
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
) )
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
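A sketch of driving the synchronous engine directly with the consolidated argument; the model name is a placeholder:

from vllm import EngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("req-0",
                   {"prompt": "The capital of France is"},
                   SamplingParams(max_tokens=8))
while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)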
def _create_sequence_group_with_sampling( def _create_sequence_group_with_sampling(
self, self,
request_id: str, request_id: str,
seq: Sequence, seq: Sequence,
sampling_params: SamplingParams, sampling_params: SamplingParams,
arrival_time: Optional[float] = None, arrival_time: float,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest],
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams.""" """Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs max_logprobs = self.get_model_config().max_logprobs
@ -495,8 +587,7 @@ class LLMEngine:
seqs=[seq], seqs=[seq],
arrival_time=arrival_time, arrival_time=arrival_time,
sampling_params=sampling_params, sampling_params=sampling_params,
lora_request=lora_request, lora_request=lora_request)
multi_modal_data=multi_modal_data)
return seq_group return seq_group
@ -505,9 +596,8 @@ class LLMEngine:
request_id: str, request_id: str,
seq: Sequence, seq: Sequence,
pooling_params: PoolingParams, pooling_params: PoolingParams,
arrival_time: Optional[float] = None, arrival_time: float,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest],
multi_modal_data: Optional[MultiModalData] = None,
) -> SequenceGroup: ) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams.""" """Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler # Defensive copy of PoolingParams, which are used by the pooler
@ -517,7 +607,6 @@ class LLMEngine:
seqs=[seq], seqs=[seq],
arrival_time=arrival_time, arrival_time=arrival_time,
lora_request=lora_request, lora_request=lora_request,
multi_modal_data=multi_modal_data,
pooling_params=pooling_params) pooling_params=pooling_params)
return seq_group return seq_group
@ -570,7 +659,7 @@ class LLMEngine:
def _process_model_outputs( def _process_model_outputs(
self, self,
output: List[Union[SamplerOutput, PoolerOutput]], output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup], scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup], ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
@ -585,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of # Organize outputs by [sequence group][step] instead of
# [step][sequence group]. # [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group( output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups)) output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs. # Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip( for scheduled_seq_group, outputs, seq_group_meta in zip(
@ -1,18 +1,20 @@
from typing import List from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from vllm.sequence import SamplerOutput, SequenceGroupOutput from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group( def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput], outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]: num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by """Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step]. [step][sequence group] into [sequence group][step].
""" """
output_by_sequence_group: List[List[SamplerOutput]] = [ output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups) [] for _ in range(num_seq_groups)
] ]
for step in sampler_outputs: for step in outputs:
for i, sequence_group_output in enumerate(step): for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output) output_by_sequence_group[i].append(sequence_group_output)
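The reshaping this helper performs, shown with plain lists in place of SamplerOutput objects:

# [step][sequence group] -> [sequence group][step]
steps = [["g0_s0", "g1_s0"], ["g0_s1", "g1_s1"]]
by_group = [[] for _ in range(2)]
for step in steps:
    for i, item in enumerate(step):
        by_group[i].append(item)
assert by_group == [["g0_s0", "g0_s1"], ["g1_s0", "g1_s1"]]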
@ -1,11 +1,14 @@
from typing import List, Optional, Union from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -13,7 +16,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter from vllm.utils import Counter, deprecate_kwargs
logger = init_logger(__name__) logger = init_logger(__name__)
@ -28,8 +31,10 @@ class LLM:
mechanism and efficient memory management. mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead. serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
NOTE: For the comprehensive list of arguments, see
:class:`~vllm.EngineArgs`.
Args: Args:
model: The name or path of a HuggingFace Transformers model. model: The name or path of a HuggingFace Transformers model.
@ -81,6 +86,18 @@ class LLM:
disable_custom_all_reduce: See ParallelConfig disable_custom_all_reduce: See ParallelConfig
""" """
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True
yield
cls.DEPRECATE_LEGACY = False
def __init__( def __init__(
self, self,
model: str, model: str,
@ -138,15 +155,101 @@ class LLM:
) -> None: ) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer self.llm_engine.tokenizer.tokenizer = tokenizer
@overload # LEGACY: single (prompt + optional token ids)
def generate( def generate(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def generate(
self,
prompts: List[str],
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None, List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def generate(
self,
prompts: Optional[List[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]: ) -> List[RequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
@ -155,49 +258,138 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
prompts: A list of prompts to generate completions for. inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt. When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt. prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns: Returns:
A list of `RequestOutput` objects containing the A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts. generated completions in the same order as the input prompts.
""" """
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if sampling_params is None: if sampling_params is None:
# Use default sampling params. # Use default sampling params.
sampling_params = SamplingParams() sampling_params = SamplingParams()
requests_data = self._validate_and_prepare_requests( self._validate_and_add_requests(
prompts, inputs=inputs,
sampling_params, params=sampling_params,
prompt_token_ids, lora_request=lora_request,
lora_request,
multi_modal_data,
) )
# Add requests to the engine and run the engine outputs = self._run_engine(use_tqdm=use_tqdm)
for request_data in requests_data: return LLMEngine.validate_outputs(outputs, RequestOutput)
self._add_request(**request_data)
return self._run_engine(use_tqdm)
@overload # LEGACY: single (prompt + optional token ids)
def encode( def encode(
self, self,
prompts: Optional[Union[str, List[str]]] = None, prompts: str,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
List[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def encode(
self,
prompts: List[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None, prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True, use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None, multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def encode(
self,
prompts: Optional[List[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[EmbeddingRequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]: ) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts. """Generates the completions for the input prompts.
@ -206,124 +398,133 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
prompts: A list of prompts to generate completions for. inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters. use the default pooling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any. lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns: Returns:
A list of `EmbeddingRequestOutput` objects containing the A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts. generated embeddings in the same order as the input prompts.
""" """
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
requests_data = self._validate_and_prepare_requests( self._validate_and_add_requests(
prompts, inputs=inputs,
pooling_params, params=pooling_params,
prompt_token_ids, lora_request=lora_request,
lora_request,
multi_modal_data,
) )
# Add requests to the engine and run the engine outputs = self._run_engine(use_tqdm=use_tqdm)
for request_data in requests_data: return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
self._add_request(**request_data)
return self._run_engine(use_tqdm) # LEGACY
def _convert_v1_inputs(
def _validate_and_prepare_requests(
self, self,
prompts: Optional[Union[str, List[str]]], prompts: Optional[Union[str, List[str]]],
params: Union[Union[SamplingParams, PoolingParams], prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
List[Union[SamplingParams, multi_modal_data: Optional[MultiModalData],
PoolingParams]]], # Unified parameter ):
prompt_token_ids: Optional[List[List[int]]] = None, # skip_tokenizer_init is now checked in engine
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[dict]:
"""Validates and prepares request data for adding to the engine.
Ensures prompts and token IDs are consistent, and returns a list of if prompts is not None:
dictionaries with request data for further processing. prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
""" if prompt_token_ids is not None:
if prompts is None and prompt_token_ids is None: prompt_token_ids = [
raise ValueError("Either prompts or prompt_token_ids must be " p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
"provided.") ]
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None: num_requests = None
raise ValueError("prompts must be None if skip_tokenizer_init " if prompts is not None:
"is True") num_requests = len(prompts)
if isinstance(prompts, str): if prompt_token_ids is not None:
# Convert a single prompt to a list. if (num_requests is not None
prompts = [prompts] and num_requests != len(prompt_token_ids)):
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids " raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.") "must be the same.")
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
num_requests = len(prompt_token_ids) num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data
inputs.append(item)
return inputs
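The legacy conversion boils down to zipping the old parallel lists into per-request dicts; a stripped-down sketch:

prompts = ["hello", "world"]
prompt_token_ids = [[1, 2], [3, 4]]
converted = [{"prompt": p, "prompt_token_ids": t}
             for p, t in zip(prompts, prompt_token_ids)]
assert converted[0] == {"prompt": "hello", "prompt_token_ids": [1, 2]}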
def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[LoRARequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests: if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params " raise ValueError("The lengths of prompts and params "
"must be the same.") "must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine. # Add requests to the engine.
requests_data = [] for i, request_inputs in enumerate(inputs):
for i in range(num_requests): self._add_request(
prompt = prompts[i] if prompts is not None else None request_inputs,
token_ids = None if prompt_token_ids is None else prompt_token_ids[ params[i] if isinstance(params, Sequence) else params,
i] lora_request=lora_request,
)
multi_modal_item = MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0),
) if multi_modal_data else None
requests_data.append({
"prompt":
prompt,
"params":
params[i] if isinstance(params, list) else params,
"prompt_token_ids":
token_ids,
"lora_request":
lora_request,
"multi_modal_data":
multi_modal_item,
})
return requests_data
def _add_request( def _add_request(
self, self,
prompt: Optional[str], inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None: ) -> None:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id, self.llm_engine.add_request(request_id,
prompt, inputs,
params, params,
prompt_token_ids, lora_request=lora_request)
lora_request=lora_request,
multi_modal_data=multi_modal_data)
def _run_engine( def _run_engine(
self, use_tqdm: bool self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm. # Initialize tqdm.
if use_tqdm: if use_tqdm:
@ -355,5 +556,4 @@ class LLM:
# Sort the outputs by request ID. # Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than # This is necessary because some requests may be finished earlier than
# its previous requests. # its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id)) return sorted(outputs, key=lambda x: int(x.request_id))
return outputs
@ -176,9 +176,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e: except ValueError as e:
return self.create_error_response(str(e)) return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt_text, sampling_params, result_generator = self.engine.generate(
request_id, prompt_ids, {
lora_request) "prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
request_id,
lora_request,
)
# Streaming response # Streaming response
if request.stream: if request.stream:
return self.chat_completion_stream_generator( return self.chat_completion_stream_generator(
@ -119,12 +119,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens) truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats prompt_ids, prompt_text = prompt_formats
generators.append( generator = self.engine.generate(
self.engine.generate(prompt_text, {
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params, sampling_params,
f"{request_id}-{i}", f"{request_id}-{i}",
prompt_token_ids=prompt_ids, lora_request=lora_request,
lora_request=lora_request)) )
generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@ -1,5 +1,5 @@
import time import time
from typing import AsyncIterator, List, Tuple from typing import AsyncIterator, List, Optional, Tuple
from fastapi import Request from fastapi import Request
@ -100,11 +100,16 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_ids, prompt_text = prompt_formats prompt_ids, prompt_text = prompt_formats
generators.append( generator = self.engine.encode(
self.engine.generate(prompt_text, {
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
pooling_params, pooling_params,
f"{request_id}-{i}", f"{request_id}-{i}",
prompt_token_ids=prompt_ids)) )
generators.append(generator)
except ValueError as e: except ValueError as e:
# TODO: Use a vllm-specific Validation Error # TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e)) return self.create_error_response(str(e))
@ -113,7 +118,9 @@ class OpenAIServingEmbedding(OpenAIServing):
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
# Non-streaming response # Non-streaming response
final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator: async for i, res in result_generator:
if await raw_request.is_disconnected(): if await raw_request.is_disconnected():
# Abort the request if the client disconnects. # Abort the request if the client disconnects.
@ -123,6 +130,9 @@ class OpenAIServingEmbedding(OpenAIServing):
final_res_batch[i] = res final_res_batch[i] = res
response = request_output_to_embedding_response( response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name) final_res_batch, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response return response
@ -143,7 +143,8 @@ class OpenAIServing:
return json_str return json_str
async def _check_model( async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest] self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[ErrorResponse]: ) -> Optional[ErrorResponse]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
@ -155,7 +156,8 @@ class OpenAIServing:
status_code=HTTPStatus.NOT_FOUND) status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora( def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest] self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]: ) -> Optional[LoRARequest]:
if request.model in self.served_model_names: if request.model in self.served_model_names:
return None return None
vllm/inputs.py (new file, 130 lines)
@ -0,0 +1,130 @@
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.sequence import MultiModalData
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt."""
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalData"]
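Because the prompt schemas are TypedDicts, plain dict literals or the class constructors both produce valid PromptStrictInputs values; a small sketch:

from vllm.inputs import TextPrompt, TokensPrompt

text_prompt = TextPrompt(prompt="Describe the sky.")
tokens_prompt = TokensPrompt(prompt_token_ids=[101, 2003, 102])
# Either value can be passed to LLM.generate() / AsyncLLMEngine.generate().
print(text_prompt, tokens_prompt)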
@ -1,4 +1,5 @@
import time import time
from dataclasses import dataclass
from typing import List, Optional, Union from typing import List, Optional, Union
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus) SequenceGroup, SequenceStatus)
@dataclass
class CompletionOutput: class CompletionOutput:
"""The output data of one completion output of a request. """The output data of one completion output of a request.
@ -24,25 +26,14 @@ class CompletionOutput:
lora_request: The LoRA request that was used to generate the output. lora_request: The LoRA request that was used to generate the output.
""" """
def __init__( index: int
self, text: str
index: int, token_ids: List[int]
text: str, cumulative_logprob: float
token_ids: List[int], logprobs: Optional[SampleLogprobs]
cumulative_logprob: float, finish_reason: Optional[str] = None
logprobs: Optional[SampleLogprobs], stop_reason: Union[int, str, None] = None
finish_reason: Optional[str] = None, lora_request: Optional[LoRARequest] = None
stop_reason: Union[int, str, None] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
def finished(self) -> bool: def finished(self) -> bool:
return self.finish_reason is not None return self.finish_reason is not None
@ -57,6 +48,7 @@ class CompletionOutput:
f"stop_reason={self.stop_reason})") f"stop_reason={self.stop_reason})")
@dataclass
class EmbeddingOutput: class EmbeddingOutput:
"""The output data of one completion output of a request. """The output data of one completion output of a request.
@ -65,15 +57,11 @@ class EmbeddingOutput:
length of vector depends on the model as listed in the embedding guide. length of vector depends on the model as listed in the embedding guide.
""" """
def __init__( embedding: List[float]
self,
embedding: List[float],
) -> None:
self.embedding = embedding
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"EmbeddingOutput(" return (f"EmbeddingOutput("
f"embedding={len(self.embedding)}") f"embedding={len(self.embedding)})")
class RequestOutput: class RequestOutput:
@ -93,7 +81,7 @@ class RequestOutput:
def __init__( def __init__(
self, self,
request_id: str, request_id: str,
prompt: str, prompt: Optional[str],
prompt_token_ids: List[int], prompt_token_ids: List[int],
prompt_logprobs: Optional[PromptLogprobs], prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput], outputs: List[CompletionOutput],
@ -183,7 +171,7 @@ class EmbeddingRequestOutput:
finished (bool): A flag indicating whether the embedding is completed. finished (bool): A flag indicating whether the embedding is completed.
""" """
def __init__(self, request_id: str, outputs: 'EmbeddingOutput', def __init__(self, request_id: str, outputs: "EmbeddingOutput",
prompt_token_ids: List[int], finished: bool): prompt_token_ids: List[int], finished: bool):
self.request_id = request_id self.request_id = request_id
self.prompt_token_ids = prompt_token_ids self.prompt_token_ids = prompt_token_ids
@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from vllm.block import LogicalTokenBlock from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
@ -210,8 +211,7 @@ class Sequence:
Args: Args:
seq_id: The ID of the sequence. seq_id: The ID of the sequence.
prompt: The prompt of the sequence. inputs: The inputs of the sequence.
prompt_token_ids: The token IDs of the prompt.
block_size: The block size of the sequence. Should be the same as the block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine. block size used by the block manager and cache engine.
lora_request: LoRA request. lora_request: LoRA request.
@ -220,25 +220,24 @@ class Sequence:
def __init__( def __init__(
self, self,
seq_id: int, seq_id: int,
prompt: str, inputs: LLMInputs,
prompt_token_ids: List[int],
block_size: int, block_size: int,
eos_token_id: Optional[int] = None, eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
) -> None: ) -> None:
self.seq_id = seq_id self.seq_id = seq_id
self.prompt = prompt self.inputs = inputs
self.block_size = block_size self.block_size = block_size
self.eos_token_id = eos_token_id self.eos_token_id = eos_token_id
self.lora_request = lora_request self.lora_request = lora_request
self.data: SequenceData = SequenceData(prompt_token_ids) self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = [] self.output_logprobs: SampleLogprobs = []
self.output_text = "" self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = [] self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids. # Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(prompt_token_ids) self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None self.stop_reason: Union[int, str, None] = None
@ -248,6 +247,18 @@ class Sequence:
# Input + output tokens # Input + output tokens
self.tokens: Optional[List[str]] = None self.tokens: Optional[List[str]] = None
@property
def prompt(self) -> Optional[str]:
return self.inputs["prompt"]
@property
def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"]
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs["multi_modal_data"]
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0 return self.lora_request.lora_int_id if self.lora_request else 0
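A sketch of the new construction path, assuming a Sequence can be built standalone from an LLMInputs dict:

from vllm.inputs import LLMInputs
from vllm.sequence import Sequence

inputs = LLMInputs(prompt_token_ids=[1, 2, 3],
                   prompt="abc",
                   multi_modal_data=None)
seq = Sequence(seq_id=0, inputs=inputs, block_size=16)
# prompt/prompt_token_ids are now properties derived from `inputs`.
print(seq.prompt, seq.prompt_token_ids)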
@ -415,7 +426,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs. sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request. arrival_time: The arrival time of the request.
lora_request: LoRA request. lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model. for an embedding model.
pooling_params: The pooling parameters used to generate the pooling pooling_params: The pooling parameters used to generate the pooling
@ -429,7 +439,6 @@ class SequenceGroup:
arrival_time: float, arrival_time: float,
sampling_params: Optional[SamplingParams] = None, sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
embeddings: Optional[List[float]] = None, embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None, pooling_params: Optional[PoolingParams] = None,
) -> None: ) -> None:
@ -444,12 +453,11 @@ class SequenceGroup:
self.lora_request = lora_request self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState() self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
self.embeddings = embeddings self.embeddings = embeddings
self.pooling_params = pooling_params self.pooling_params = pooling_params
@property @property
def prompt(self) -> str: def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt. # All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt return next(iter(self.seqs_dict.values())).prompt
@ -458,7 +466,13 @@ class SequenceGroup:
def prompt_token_ids(self) -> List[int]: def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt. # All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence. # We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional[MultiModalData]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
@property @property
def lora_int_id(self) -> int: def lora_int_id(self) -> int:
@ -11,7 +11,7 @@ import threading
import uuid import uuid
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from functools import lru_cache, partial from functools import lru_cache, partial, wraps
from platform import uname from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar, Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None:
filename) filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True) os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path) enable_trace_function_call(log_path)
def identity(value: T) -> T:
return value
F = TypeVar('F', bound=Callable[..., Any])
def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None) -> Callable[[F], F]:
deprecated_kws = set(kws)
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
def wrapper(fn: F) -> F:
@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_kwargs = kwargs.keys() & deprecated_kws
if deprecated_kwargs:
msg = (
f"The keyword arguments {deprecated_kwargs} are "
"deprecated and will be removed in a future update.")
if additional_message is not None:
msg += f" {additional_message}"
warnings.warn(
DeprecationWarning(msg),
stacklevel=3, # The inner function takes up one level
)
return fn(*args, **kwargs)
return inner # type: ignore
return wrapper
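A sketch of the decorator in use; the function and argument names are made up for illustration:

import warnings

from vllm.utils import deprecate_kwargs


@deprecate_kwargs("old_arg",
                  is_deprecated=True,
                  additional_message="Use 'new_arg' instead.")
def f(new_arg=None, old_arg=None):
    return new_arg if new_arg is not None else old_arg


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    f(old_arg=1)
print(caught[0].message)  # mentions {'old_arg'} being deprecated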