[Core] Consolidate prompt arguments to LLM engines (#4328)

Co-authored-by: Roger Wang <ywang@roblox.com>
Cyrus Leung 2024-05-29 04:29:31 +08:00 committed by GitHub
parent 290f4ada2b
commit 5ae5ed1e60
43 changed files with 1407 additions and 442 deletions
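In short, this change folds the separate `prompts`, `prompt_token_ids`, and `multi_modal_data` keyword arguments into a single prompt-inputs argument across `LLM`, `LLMEngine`, and `AsyncLLMEngine`. A minimal sketch of the new call shapes, based on the tests and examples in this diff (the model name and prompts are taken from those tests):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, top_p=1.0)

# A plain string, a {"prompt": ...} dict, and a {"prompt_token_ids": ...}
# dict are all accepted as a single positional inputs argument.
outputs = llm.generate("Hello, my name is", sampling_params=params)
outputs = llm.generate({"prompt": "Hello, my name is"}, sampling_params=params)
outputs = llm.generate({"prompt_token_ids": [0, 1, 2]}, sampling_params=params)

# Batched requests pass a list of such inputs instead of parallel lists.
outputs = llm.generate([{"prompt": p} for p in ["Hello, my name is",
                                                "The capital of France is"]],
                       sampling_params=params)
```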

View File

@ -63,9 +63,9 @@ steps:
mirror_hardwares: [amd]
commands:
# these tests have to be separated, because each one will allocate all possible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py
- pytest -v -s test_inputs.py
- pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints -m openai
- label: Examples Test
working_dir: "/vllm-workspace/examples"
@ -110,6 +110,9 @@ steps:
mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py
- label: Utils Test
command: pytest -v -s test_utils.py
- label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker

View File

@ -3,13 +3,14 @@ import argparse
import json
import time
from pathlib import Path
from typing import Optional
from typing import List, Optional
import numpy as np
import torch
from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
@ -59,13 +62,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()

View File

@ -1,5 +1,5 @@
LLM Class
==========
=========
.. autoclass:: vllm.LLM
:members:

View File

@ -0,0 +1,14 @@
LLM Inputs
==========
.. autodata:: vllm.inputs.PromptStrictInputs
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:members:
:member-order: bysource
.. autoclass:: vllm.inputs.TokensPrompt
:show-inheritance:
:members:
:member-order: bysource
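For illustration, the documented input types could be used as shown below; this is a sketch assuming `TextPrompt` and `TokensPrompt` behave as dict-like inputs (as the autoclass directives above suggest) that can be passed directly to `LLM.generate()` (the model name is illustrative):

```python
from vllm import LLM
from vllm.inputs import TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # model name is illustrative

# Both forms are accepted wherever PromptStrictInputs is expected.
text_outputs = llm.generate(TextPrompt(prompt="The future of AI is"))
token_outputs = llm.generate(TokensPrompt(prompt_token_ids=[0, 3, 1, 2]))
```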

View File

@ -0,0 +1,8 @@
Offline Inference
=================================
.. toctree::
:maxdepth: 1
llm
llm_inputs

View File

@ -68,13 +68,6 @@ Documentation
getting_started/quickstart
getting_started/examples/examples_index
.. toctree::
:maxdepth: 1
:caption: Offline Inference
offline_inference/llm
offline_inference/sampling_params
.. toctree::
:maxdepth: 1
:caption: Serving
@ -108,7 +101,9 @@ Documentation
.. toctree::
:maxdepth: 2
:caption: Developer Documentation
dev/sampling_params
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/dockerfile/dockerfile

View File

@ -48,7 +48,7 @@ completion = client.chat.completions.create(
```
### Extra Parameters for Chat API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
@ -65,7 +65,7 @@ The following extra parameters are supported:
```
### Extra Parameters for Completions API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python

View File

@ -23,11 +23,15 @@ def run_llava_pixel_values():
"\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_pixel_values.pt")
image = torch.load("images/stop_sign_pixel_values.pt")
outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})
outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
@ -46,11 +50,14 @@ def run_llava_image_features():
"\nUSER: What is the content of this image?\nASSISTANT:")
# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_image_features.pt")
image = torch.load("images/stop_sign_image_features.pt")
outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)

View File

@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort]
use_parentheses = true
skip_gitignore = true
[tool.pytest.ini_options]
markers = [
"skip_global_cleanup",
"llm: run tests for vLLM API only",
"openai: run tests for OpenAI API only",
]

View File

@ -25,7 +25,7 @@ class MockEngine:
return [RequestOutput(
request_id=self.request_id)] if self.request_id else []
async def encode_request_async(self, *args, **kwargs):
async def process_model_inputs_async(self, *args, **kwargs):
pass
def generate(self, request_id):

View File

@ -29,7 +29,7 @@ def server():
ray.shutdown()
@pytest.fixture(scope="session")
@pytest.fixture(scope="module")
def client():
client = openai.AsyncOpenAI(
base_url="http://localhost:8000/v1",

View File

@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
from vllm import LLM, SamplingParams
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
from vllm.distributed import destroy_model_parallel
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.sequence import MultiModalData
@ -402,12 +403,22 @@ class VllmRunner:
) -> List[Tuple[List[int], str]]:
if images is not None:
assert len(prompts) == images.shape[0]
req_outputs = self.model.generate(
prompts,
sampling_params=sampling_params,
multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
data=images)
if images is not None else None)
prompt_inputs: List[PromptInputs] = []
for i, prompt in enumerate(prompts):
image = None if images is None else images[i:i + 1]
mm_data = None if image is None else MultiModalData(
type=MultiModalData.Type.IMAGE,
data=image,
)
prompt_inputs.append({
"prompt": prompt,
"multi_modal_data": mm_data,
})
req_outputs = self.model.generate(prompt_inputs,
sampling_params=sampling_params)
outputs = []
for req_output in req_outputs:
prompt_str = req_output.prompt

View File

@ -133,8 +133,11 @@ def test_append_slot_cow():
# Allocate prompt to gpu block. There is one slot left in the block.
prompt = Sequence(seq_id=1,
prompt="one two three",
prompt_token_ids=[1, 2, 3],
inputs={
"prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
"multi_modal_data": None
},
block_size=block_size)
# Fork the sequence, such that a COW will be required when we append a new
@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks
parent = Sequence(1, "one two three", [0, 1, 2], block_size)
parent = Sequence(seq_id=1,
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
"multi_modal_data": None
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1",
seqs=[parent],
arrival_time=time.time(),

View File

@ -21,7 +21,13 @@ def create_dummy_prompt(
# and prompt "0 ... block_size".
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
prompt = Sequence(int(request_id),
inputs={
"prompt": prompt_str,
"prompt_token_ids": prompt_tokens,
"multi_modal_data": None,
},
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
arrival_time=time.time(),
@ -51,8 +57,11 @@ def create_seq_group(
for seq_id_offset, output_len in enumerate(seq_output_lens):
seq = Sequence(
seq_id=seq_id_start + seq_id_offset,
prompt="",
prompt_token_ids=prompt_token_ids,
inputs={
"prompt": "",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)

View File

@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
with pytest.raises(ValueError) as err:
llm.generate("abc", sampling_params)
assert "prompts must be None if" in str(err.value)
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs

View File

@ -1,11 +1,15 @@
import asyncio
from dataclasses import dataclass
import pytest
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
MODEL_NAME = "openai-community/gpt2"
CHAT_TEMPLATE = "Dummy chat template for testing {}"
pytestmark = pytest.mark.openai
@dataclass
class MockModelConfig:

View File

@ -52,6 +52,8 @@ TEST_SCHEMA = {
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
pytestmark = pytest.mark.openai
def test_guided_logits_processors():
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""

View File

@ -0,0 +1,144 @@
import weakref
from typing import List
import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from ..conftest import cleanup
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
TOKEN_IDS = [
# Using ID={0, 1, 2, 3} results in NaN values,
# so we add this offset of 1000
[1000],
[1000, 1001],
[1000, 1002, 1001],
[1000, 1003, 1001, 1002],
]
pytestmark = pytest.mark.llm
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=32768,
tensor_parallel_size=1,
gpu_memory_utilization=0.75,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
o2: List[EmbeddingRequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=prompt_token_ids,
pooling_params=pooling_params)
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.encode(prompt_token_ids=TOKEN_IDS,
pooling_params=pooling_params)
v2_output = llm.encode(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_multiple_pooling_params(llm: LLM):
pooling_params = [
PoolingParams(),
PoolingParams(),
PoolingParams(),
PoolingParams(),
]
# Multiple PoolingParams should be matched with each prompt
outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
assert len(PROMPTS) == len(outputs)
# Exception raised if the size of params does not match the size of prompts
with pytest.raises(ValueError):
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
# Single PoolingParams should be applied to every prompt
single_pooling_params = PoolingParams()
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
assert len(PROMPTS) == len(outputs)
# pooling_params is None, default params should be applied
outputs = llm.encode(PROMPTS, pooling_params=None)
assert len(PROMPTS) == len(outputs)

View File

@ -1,21 +1,124 @@
import weakref
from typing import List
import pytest
from vllm import LLM, SamplingParams
from vllm import LLM, RequestOutput, SamplingParams
from ..conftest import cleanup
MODEL_NAME = "facebook/opt-125m"
PROMPTS = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
TOKEN_IDS = [
[0],
[0, 1],
[0, 2, 1],
[0, 3, 1, 2],
]
pytestmark = pytest.mark.llm
def test_multiple_sampling_params():
llm = LLM(model="facebook/opt-125m",
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1)
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)
v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
prompt_token_ids):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params)
v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)
v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
sampling_params=sampling_params)
v2_output = llm.generate(
[{
"prompt_token_ids": p
} for p in TOKEN_IDS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_multiple_sampling_params(llm: LLM):
sampling_params = [
SamplingParams(temperature=0.01, top_p=0.95),
SamplingParams(temperature=0.3, top_p=0.95),
@ -24,18 +127,18 @@ def test_multiple_sampling_params():
]
# Multiple SamplingParams should be matched with each prompt
outputs = llm.generate(prompts, sampling_params=sampling_params)
assert len(prompts) == len(outputs)
outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
assert len(PROMPTS) == len(outputs)
# Exception raised if the size of params does not match the size of prompts
with pytest.raises(ValueError):
outputs = llm.generate(prompts, sampling_params=sampling_params[:3])
outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3])
# Single SamplingParams should be applied to every prompt
single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
outputs = llm.generate(prompts, sampling_params=single_sampling_params)
assert len(prompts) == len(outputs)
outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
assert len(PROMPTS) == len(outputs)
# sampling_params is None, default params should be applied
outputs = llm.generate(prompts, sampling_params=None)
assert len(prompts) == len(outputs)
outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs)
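These tests also show how the legacy keyword API is handled: it keeps working but emits a `DeprecationWarning` once the class-level `deprecate_legacy_api()` toggle is active. A condensed sketch of that pattern, using the same names as the tests (engine settings are illustrative):

```python
import pytest

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enforce_eager=True)
params = SamplingParams(temperature=0.0, top_p=1.0)

with llm.deprecate_legacy_api():
    # Legacy keyword form: still works, but now warns.
    with pytest.warns(DeprecationWarning, match="'prompts'"):
        legacy = llm.generate(prompts="Hello, my name is",
                              sampling_params=params)

    # Consolidated form: a single positional inputs argument, no warning.
    new = llm.generate({"prompt": "Hello, my name is"},
                       sampling_params=params)

assert [o.outputs for o in legacy] == [o.outputs for o in new]
```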

View File

@ -71,7 +71,7 @@ TEST_CHOICE = [
"Swift", "Kotlin"
]
pytestmark = pytest.mark.asyncio
pytestmark = pytest.mark.openai
@pytest.fixture(scope="session")
@ -91,6 +91,8 @@ def server(zephyr_lora_files):
"--max-model-len",
"8192",
"--enforce-eager",
"--gpu-memory-utilization",
"0.75",
# lora config below
"--enable-lora",
"--lora-modules",
@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files):
# use half precision for speed and memory savings in CI environment
"--dtype",
"bfloat16",
"--enforce-eager",
"--gpu-memory-utilization",
"0.75",
"--max-model-len",
"8192",
"--enforce-eager",
])
ray.get(server_runner.ready.remote())
yield server_runner
@ -136,6 +140,7 @@ def client():
yield client
@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI):
models = await client.models.list()
models = models.data
@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
assert lora_models[1].id == "zephyr-lora2"
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
assert choice.logprobs.top_logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
assert "".join(chunks) == output
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
assert texts[0] == texts[1]
@pytest.mark.asyncio
async def test_logits_bias(server, client: openai.AsyncOpenAI):
prompt = "Hello, my name is"
max_tokens = 5
@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
assert first_response != completion.choices[0].text
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
assert json1["age"] != json2["age"]
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
assert ip1 != ip2
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
assert completion.choices[i].text in TEST_CHOICE
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
assert choice1 != choice2
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
for token, logprob in token_dict.items())
@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
for _ in range(2):
resp = await client.chat.completions.create(
@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
assert loaded == {"result": 2}, loaded
@pytest.mark.asyncio
async def test_extra_fields(server, client: openai.AsyncOpenAI):
with pytest.raises(BadRequestError) as exc_info:
await client.chat.completions.create(
@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
assert "extra_forbidden" in exc_info.value.message
@pytest.mark.asyncio
async def test_complex_message_content(server, client: openai.AsyncOpenAI):
resp = await client.chat.completions.create(
model=MODEL_NAME,
@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
assert content == "2"
@pytest.mark.asyncio
async def test_custom_role(server, client: openai.AsyncOpenAI):
# Not sure how the model handles custom roles so we just check that
# both string and complex message content are handled in the same way
@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI):
assert content1 == content2
@pytest.mark.asyncio
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
simple_sql_grammar = """
start: select_statement
@ -847,6 +873,7 @@ number: "1" | "2"
assert content.strip() == ground_truth
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
assert len(logprobs.tokens) > 5
@pytest.mark.asyncio
async def test_long_seed(server, client: openai.AsyncOpenAI):
for seed in [
torch.iinfo(torch.long).min - 1,
@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
or "less_than_equal" in exc_info.value.message)
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],
@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
assert embeddings.usage.total_tokens == 5
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[EMBEDDING_MODEL_NAME],

View File

@ -1,7 +1,7 @@
import multiprocessing
import sys
import time
import pytest
import torch
from openai import OpenAI, OpenAIError
@ -10,6 +10,8 @@ from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.utils import get_open_port
pytestmark = pytest.mark.openai
class MyOPTForCausalLM(OPTForCausalLM):
@ -26,15 +28,16 @@ def server_function(port):
# register our dummy model
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
sys.argv = ["placeholder.py"] + \
("--model facebook/opt-125m --dtype"
f" float32 --api-key token-abc123 --port {port}").split()
("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
f"--dtype float32 --api-key token-abc123 --port {port}").split()
import runpy
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
def test_oot_registration_for_api_server():
port = get_open_port()
server = multiprocessing.Process(target=server_function, args=(port, ))
ctx = torch.multiprocessing.get_context()
server = ctx.Process(target=server_function, args=(port, ))
server.start()
client = OpenAI(
base_url=f"http://localhost:{port}/v1",

View File

@ -86,20 +86,18 @@ def generate(
def batched_generate(
llm,
llm: vllm.LLM,
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
):
for input in inputs:
prompt, sampling_param, lora_req = input
requests_data = llm._validate_and_prepare_requests(
# Add requests to the engine and run the engine
llm._validate_and_add_requests(
prompt,
sampling_param,
lora_request=lora_req,
)
# Add requests to the engine and run the engine
for request_data in requests_data:
llm._add_request(**request_data)
outputs = llm._run_engine(use_tqdm=True)
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]

View File

@ -35,28 +35,25 @@ def test_logits_processor_force_generate(
# test logits_processors when prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[0],
example_prompts[0],
params=params_with_logprobs,
prompt_token_ids=None,
)
# test prompt_logprobs is not None
vllm_model.model._add_request(
prompt=example_prompts[1],
example_prompts[1],
params=SamplingParams(
prompt_logprobs=3,
max_tokens=max_tokens,
),
prompt_token_ids=None,
)
# test grouped requests
vllm_model.model._add_request(
prompt=example_prompts[2],
example_prompts[2],
params=SamplingParams(max_tokens=max_tokens),
prompt_token_ids=None,
)
outputs = vllm_model.model._run_engine(False)
outputs = vllm_model.model._run_engine(use_tqdm=False)
assert outputs[0].outputs[0].text == enforced_answers * repeat_times

View File

@ -57,11 +57,7 @@ def test_random_sample_with_seed(
sampling_params_seed_1,
sampling_params_seed_2,
):
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
params=params,
)
llm._add_request(prompt, params=params)
results = llm._run_engine(use_tqdm=False)
all_outputs = [[out.token_ids for out in output.outputs]

View File

@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
for prompt in prompts:
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
tokenizer.tokenizer.eos_token_id, lora_request)
seq = Sequence(seq_id,
inputs={
"prompt": prompt,
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=block_size,
eos_token_id=tokenizer.tokenizer.eos_token_id,
lora_request=lora_request)
num_blocks = len(prompt_token_ids) // block_size
for idx in range(num_blocks):

tests/test_inputs.py (new file, 53 lines)
View File

@ -0,0 +1,53 @@
from typing import List
import pytest
from vllm.inputs import parse_and_batch_prompt
STRING_INPUTS = [
'',
'foo',
'foo bar',
'foo baz bar',
'foo bar qux baz',
]
TOKEN_INPUTS = [
[-1],
[1],
[1, 2],
[1, 3, 4],
[1, 2, 4, 3],
]
INPUTS_SLICES = [
slice(None, None, -1),
slice(None, None, 2),
slice(None, None, -2),
]
def test_parse_single_batch_empty():
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([])
with pytest.raises(ValueError, match="at least one prompt"):
parse_and_batch_prompt([[]])
@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
assert parse_and_batch_prompt(string_input) \
== parse_and_batch_prompt([string_input])
@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]):
assert parse_and_batch_prompt(token_input) \
== parse_and_batch_prompt([token_input])
@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
== parse_and_batch_prompt(STRING_INPUTS[inputs_slice])

tests/test_utils.py (new file, 63 lines)
View File

@ -0,0 +1,63 @@
import pytest
from vllm.utils import deprecate_kwargs
from .utils import error_on_warning
def test_deprecate_kwargs_always():
@deprecate_kwargs("old_arg", is_deprecated=True)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_never():
@deprecate_kwargs("old_arg", is_deprecated=False)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_dynamic():
is_deprecated = True
@deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="'old_arg'"):
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
is_deprecated = False
with error_on_warning():
dummy(old_arg=1)
with error_on_warning():
dummy(new_arg=1)
def test_deprecate_kwargs_additional_message():
@deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
def dummy(*, old_arg: object = None, new_arg: object = None):
pass
with pytest.warns(DeprecationWarning, match="abcd"):
dummy(old_arg=1)
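For orientation, a minimal decorator that would satisfy the expectations encoded in these tests; this is an illustrative sketch only, not necessarily the actual `vllm.utils.deprecate_kwargs` implementation:

```python
# Illustrative sketch of a deprecate_kwargs decorator; not vLLM's actual code.
import functools
import warnings
from typing import Callable, Union


def deprecate_kwargs(*kw_names: str,
                     is_deprecated: Union[bool, Callable[[], bool]] = True,
                     additional_message: str = ""):
    deprecated = set(kw_names)

    def wrapper(fn):

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            # is_deprecated may be a flag or a callable evaluated per call.
            active = is_deprecated() if callable(is_deprecated) else is_deprecated
            used = deprecated & kwargs.keys()
            if active and used:
                names = ", ".join(repr(k) for k in sorted(used))
                msg = f"The keyword arguments {names} are deprecated."
                if additional_message:
                    msg += f" {additional_message}"
                warnings.warn(msg, DeprecationWarning, stacklevel=2)
            return fn(*args, **kwargs)

        return inner

    return wrapper
```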

View File

@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
prompt="<s>",
prompt_token_ids=prompt_token_ids,
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)

View File

@ -2,6 +2,8 @@ import os
import subprocess
import sys
import time
import warnings
from contextlib import contextmanager
import ray
import requests
@ -87,3 +89,15 @@ def multi_process_tensor_parallel(
ray.get(refs)
ray.shutdown()
@contextmanager
def error_on_warning():
"""
Within the scope of this context manager, tests will fail if any warning
is emitted.
"""
with warnings.catch_warnings():
warnings.simplefilter("error")
yield

View File

@ -5,6 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@ -16,6 +17,9 @@ __version__ = "0.4.2"
__all__ = [
"LLM",
"ModelRegistry",
"PromptStrictInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
"RequestOutput",
"CompletionOutput",

View File

@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.usage.usage_lib import UsageContext
logger = init_logger(__name__)
@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine):
return request_outputs
async def encode_request_async(
async def process_model_inputs_async(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = await self.tokenizer.encode_async(
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = await tokenizer.encode_async(
request_id=request_id,
prompt=prompt,
prompt=inputs["prompt"],
lora_request=lora_request)
return prompt_token_ids
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
async def add_request_async(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = await self.encode_request_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
return self.add_request(request_id,
prompt=prompt,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
processed_inputs = await self.process_model_inputs_async(
request_id=request_id, inputs=inputs, lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
async def check_health_async(self) -> None:
self.model_executor.check_health()
class AsyncLLMEngine:
"""An asynchronous wrapper for LLMEngine.
"""An asynchronous wrapper for :class:`LLMEngine`.
This class is used to wrap the LLMEngine class to make it asynchronous. It
uses asyncio to create a background loop that keeps processing incoming
requests. The LLMEngine is kicked by the generate method when there
are requests in the waiting queue. The generate method yields the outputs
from the LLMEngine to the caller.
This class is used to wrap the :class:`LLMEngine` class to make it
asynchronous. It uses asyncio to create a background loop that keeps
processing incoming requests. The :class:`LLMEngine` is kicked by the
generate method when there are requests in the waiting queue. The generate
method yields the outputs from the :class:`LLMEngine` to the caller.
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
@ -315,8 +321,8 @@ class AsyncLLMEngine:
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for LLMEngine.
*kwargs: Arguments for LLMEngine.
*args: Arguments for :class:`LLMEngine`.
**kwargs: Arguments for :class:`LLMEngine`.
"""
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
@ -526,22 +532,26 @@ class AsyncLLMEngine:
async def add_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
shortened_token_ids = prompt_token_ids
if self.max_log_len is not None:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:self.max_log_len]
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:self.
max_log_len]
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
@ -562,39 +572,33 @@ class AsyncLLMEngine:
arrival_time = time.time()
if self.engine_use_ray:
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
processed_inputs = await self.engine.process_model_inputs_async \
.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
inputs=inputs,
lora_request=lora_request)
else:
prompt_token_ids = await self.engine.encode_request_async(
processed_inputs = await self.engine.process_model_inputs_async(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
inputs=inputs,
lora_request=lora_request)
stream = self._request_tracker.add_request(
request_id,
prompt=prompt,
inputs=processed_inputs,
params=params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
return stream
async def generate(
self,
prompt: Optional[str],
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@ -603,14 +607,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `RequestOutput` objects from the LLMEngine
@ -659,24 +661,20 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
async for output in self._process_request(
request_id,
prompt,
inputs,
sampling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
lora_request=lora_request,
):
yield output
yield LLMEngine.validate_output(output, RequestOutput)
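To make the signature change concrete, a condensed usage sketch of the updated `AsyncLLMEngine.generate()`: the prompt-related keywords collapse into a single positional `inputs` argument. Engine construction and the model name here are illustrative.

```python
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(temperature=0.0, max_tokens=16)

    # inputs may be a string, {"prompt": ...}, or {"prompt_token_ids": ...}.
    async for output in engine.generate({"prompt": "Hello, my name is"},
                                        params,
                                        request_id="request-0"):
        if output.finished:
            print(output.outputs[0].text)


asyncio.run(main())
```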
async def encode(
self,
prompt: Optional[str],
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
@ -685,14 +683,12 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data per request.
Yields:
The output `EmbeddingRequestOutput` objects from the LLMEngine
@ -739,24 +735,21 @@ class AsyncLLMEngine:
>>> # Process and return the final output
>>> ...
"""
async for output in self.process_request(
async for output in self._process_request(
request_id,
prompt,
inputs,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
lora_request=lora_request,
):
yield output
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
async def process_request(
async def _process_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
*,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or
PoolingParams."""
@ -764,12 +757,10 @@ class AsyncLLMEngine:
stream = await self.add_request(
request_id,
prompt,
inputs,
params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
)
try:

View File

@ -1,5 +1,8 @@
import time
from typing import Iterable, List, Optional, Type, Union
from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer
@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
MultiModalData, PoolerOutput, SamplerOutput,
Sequence, SequenceGroup, SequenceGroupMetadata,
PoolerOutput, SamplerOutput, Sequence,
SequenceGroup, SequenceGroupMetadata,
SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
return {}
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
class LLMEngine:
"""An LLM engine that receives requests and generates texts.
@ -60,11 +67,11 @@ class LLMEngine:
iteration-level scheduling and efficient memory management to maximize the
serving throughput.
The `LLM` class wraps this class for offline batched inference and the
`AsyncLLMEngine` class wraps this class for online serving.
The :class:`~vllm.LLM` class wraps this class for offline batched inference
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
NOTE: The config arguments are derived from the `EngineArgs` class. For the
comprehensive list of arguments, see `EngineArgs`.
NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
class. For the comprehensive list of arguments, see :ref:`engine_args`.
Args:
model_config: The configuration related to the LLM model.
@ -81,9 +88,60 @@ class LLMEngine:
executor_class: The model executor class for managing distributed
execution.
log_stats: Whether to log statistics.
usage_context: Specified entry point, used for usage info collection
usage_context: Specified entry point, used for usage info collection.
"""
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
"""A flag to toggle whether to validate the type of request output."""
@classmethod
@contextmanager
def enable_output_validation(cls):
cls.DO_VALIDATE_OUTPUT = True
yield
cls.DO_VALIDATE_OUTPUT = False
@classmethod
def validate_output(
cls,
output: object,
output_type: Type[_O],
) -> _O:
do_validate = cls.DO_VALIDATE_OUTPUT
if ((TYPE_CHECKING or do_validate)
and not isinstance(output, output_type)):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
return output
@classmethod
def validate_outputs(
cls,
outputs: GenericSequence[object],
output_type: Type[_O],
) -> List[_O]:
do_validate = cls.DO_VALIDATE_OUTPUT
outputs_: List[_O]
if TYPE_CHECKING or do_validate:
outputs_ = []
for output in outputs:
if not isinstance(output, output_type):
raise TypeError(f"Expected output of type {output_type}, "
f"but found type {type(output)}")
outputs_.append(output)
else:
outputs_ = outputs
return outputs_
tokenizer: Optional[BaseTokenizerGroup]
def __init__(
self,
model_config: ModelConfig,
@ -151,12 +209,11 @@ class LLMEngine:
self.log_stats = log_stats
if not self.model_config.skip_tokenizer_init:
self.tokenizer: BaseTokenizerGroup
self._init_tokenizer()
self.tokenizer = self._init_tokenizer()
self.detokenizer = Detokenizer(self.tokenizer)
else:
self.detokenizer = None
self.tokenizer = None
self.detokenizer = None
self.seq_counter = Counter()
self.generation_config_fields = _load_generation_config_dict(
@ -318,14 +375,26 @@ class LLMEngine:
if model_executor := getattr(self, "model_executor", None):
model_executor.shutdown()
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
"skip_tokenizer_init is True")
def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
return self.tokenizer
def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)
return self.get_tokenizer_group().get_lora_tokenizer(None)
def get_tokenizer_for_seq(self,
sequence: Sequence) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)
def _init_tokenizer(self, **tokenizer_init_kwargs):
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
@ -335,8 +404,9 @@ class LLMEngine:
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)
self.tokenizer = get_tokenizer_group(
self.parallel_config.tokenizer_pool_config, **init_kwargs)
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
@ -346,29 +416,85 @@ class LLMEngine:
self.lora_config.verify_with_scheduler_config(
self.scheduler_config)
def encode_request(
def _get_eos_token_id(
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
if self.tokenizer is None:
logger.warning("Using None for EOS token id because tokenizer "
"is not initialized")
return None
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
def _add_processed_request(
self,
request_id: str, # pylint: disable=unused-argument
prompt: Optional[str],
prompt_token_ids: Optional[List[int]] = None,
request_id: str,
processed_inputs: LLMInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = self._get_eos_token_id(lora_request)
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time=arrival_time,
lora_request=lora_request,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
def process_model_inputs(
self,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
):
if prompt_token_ids is None:
assert prompt is not None
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
return prompt_token_ids
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
if "prompt_token_ids" not in inputs:
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
prompt_token_ids = tokenizer.encode(request_id=request_id,
prompt=inputs["prompt"],
lora_request=lora_request)
else:
prompt_token_ids = inputs["prompt_token_ids"]
return LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
multi_modal_data=inputs.get("multi_modal_data"))
def add_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
"""Add a request to the engine's request pool.
@ -378,15 +504,14 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
params: Parameters for sampling or pooling. SamplingParams
for text generation. PoolingParams for pooling.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
multi_modal_data: Multi modal data per request.
Details:
- Set arrival_time to the current time if it is None.
@ -417,59 +542,26 @@ class LLMEngine:
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(
processed_inputs = self.process_model_inputs(request_id=request_id,
inputs=inputs,
lora_request=lora_request)
self._add_processed_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)
# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)
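To make the new `add_request()` shape concrete, a small sketch of driving the engine directly (engine construction and model name are illustrative; only the argument shape is the point):

```python
from vllm import EngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
params = SamplingParams(temperature=0.0, max_tokens=16)

# Each request now takes a single PromptInputs value instead of separate
# prompt / prompt_token_ids / multi_modal_data keyword arguments.
engine.add_request("0", "Hello, my name is", params)
engine.add_request("1", {"prompt_token_ids": [0, 3, 1, 2]}, params)

while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.request_id, request_output.outputs[0].text)
```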
def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@ -495,8 +587,7 @@ class LLMEngine:
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
lora_request=lora_request)
return seq_group
@ -505,9 +596,8 @@ class LLMEngine:
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
@ -517,7 +607,6 @@ class LLMEngine:
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
pooling_params=pooling_params)
return seq_group
@ -570,7 +659,7 @@ class LLMEngine:
def _process_model_outputs(
self,
output: List[Union[SamplerOutput, PoolerOutput]],
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -585,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
output, num_seq_groups=len(scheduled_seq_groups))
# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(

View File

@ -1,18 +1,20 @@
from typing import List
from typing import Sequence as GenericSequence
from typing import Union
from vllm.sequence import SamplerOutput, SequenceGroupOutput
from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput
def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput],
outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group: List[List[SamplerOutput]] = [
output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs:
for step in outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)

View File

@ -1,11 +1,14 @@
from typing import List, Optional, Union
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -13,7 +16,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
from vllm.utils import Counter, deprecate_kwargs
logger = init_logger(__name__)
@ -28,8 +31,10 @@ class LLM:
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see
:class:`~vllm.EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
@ -81,6 +86,18 @@ class LLM:
disable_custom_all_reduce: See ParallelConfig
"""
DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True
yield
cls.DEPRECATE_LEGACY = False
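
A small sketch of how this toggle could be exercised (for example from a test); the class-level flag is set on entry and cleared on exit, and while it is set the deprecate_kwargs decorator below emits DeprecationWarning for the legacy keywords.

from vllm import LLM

assert not LLM.DEPRECATE_LEGACY

with LLM.deprecate_legacy_api():
    # Inside the context, calls that still use 'prompts=' / 'prompt_token_ids='
    # keywords trigger a DeprecationWarning.
    assert LLM.DEPRECATE_LEGACY

assert not LLM.DEPRECATE_LEGACY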
def __init__(
self,
model: str,
@ -138,15 +155,101 @@ class LLM:
) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer
@overload # LEGACY: single (prompt + optional token ids)
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def generate(
self,
prompts: List[str],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def generate(
self,
prompts: Optional[List[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
@ -155,49 +258,138 @@ class LLM:
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
requests_data = self._validate_and_prepare_requests(
prompts,
sampling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
self._validate_and_add_requests(
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
)
# Add requests to the engine and run the engine
for request_data in requests_data:
self._add_request(**request_data)
return self._run_engine(use_tqdm)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
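
A usage sketch of the consolidated generate() API; the model name and token IDs are illustrative placeholders. Plain strings, TextPrompt dicts, and TokensPrompt dicts can now be mixed in a single batched call.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8, max_tokens=16)

outputs = llm.generate(
    [
        "Hello, my name is",                     # plain string
        {"prompt": "The capital of France is"},  # TextPrompt
        {"prompt_token_ids": [2, 100, 200]},     # TokensPrompt (placeholder IDs)
    ],
    sampling_params=params,
)
for out in outputs:
    print(out.prompt, "->", out.outputs[0].text)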
@overload # LEGACY: single (prompt + optional token ids)
def encode(
self,
prompts: Optional[Union[str, List[str]]] = None,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
List[PoolingParams]]] = None,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (prompt + optional token ids)
def encode(
self,
prompts: List[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single (token ids + optional prompt)
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: multi (token ids + optional prompt)
def encode(
self,
prompts: Optional[List[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload # LEGACY: single or multi token ids [pos-only]
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[EmbeddingRequestOutput]:
...
@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.
@ -206,124 +398,133 @@ class LLM:
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.
Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
requests_data = self._validate_and_prepare_requests(
prompts,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
self._validate_and_add_requests(
inputs=inputs,
params=pooling_params,
lora_request=lora_request,
)
# Add requests to the engine and run the engine
for request_data in requests_data:
self._add_request(**request_data)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
return self._run_engine(use_tqdm)
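
A corresponding sketch for the embedding path, assuming an embedding model is loaded (the model name is illustrative); each EmbeddingRequestOutput carries an EmbeddingOutput in its outputs field.

from vllm import LLM
from vllm.pooling_params import PoolingParams

llm = LLM(model="intfloat/e5-mistral-7b-instruct")  # illustrative embedding model
outputs = llm.encode(
    ["query: what is vLLM?",
     {"prompt": "passage: vLLM is a high-throughput inference engine."}],
    pooling_params=PoolingParams(),
)
for out in outputs:
    print(len(out.outputs.embedding))  # embedding dimension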
def _validate_and_prepare_requests(
# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, List[str]]],
params: Union[Union[SamplingParams, PoolingParams],
List[Union[SamplingParams,
PoolingParams]]], # Unified parameter
prompt_token_ids: Optional[List[List[int]]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[dict]:
"""Validates and prepares request data for adding to the engine.
Ensures prompts and token IDs are consistent, and returns a list of
dictionaries with request data for further processing.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
multi_modal_data: Optional[MultiModalData],
):
# skip_tokenizer_init is now checked in engine
if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None:
prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
]
num_requests = None
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
if prompt_token_ids is not None:
if (num_requests is not None
and num_requests != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data
inputs.append(item)
return inputs
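
The conversion above amounts to zipping the legacy arguments into per-request dicts; a standalone sketch of the resulting structure (token IDs are placeholders, not a real tokenization).

from vllm.inputs import TextTokensPrompt, parse_and_batch_prompt

prompts = ["Hello", "World"]
prompt_token_ids = [[1, 2], [3, 4]]  # placeholder IDs

inputs = [
    TextTokensPrompt(prompt=text["content"], prompt_token_ids=ids["content"])
    for text, ids in zip(parse_and_batch_prompt(prompts),
                         parse_and_batch_prompt(prompt_token_ids))
]
# -> [{'prompt': 'Hello', 'prompt_token_ids': [1, 2]},
#     {'prompt': 'World', 'prompt_token_ids': [3, 4]}]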
def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[LoRARequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)
# Add requests to the engine.
requests_data = []
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
multi_modal_item = MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0),
) if multi_modal_data else None
requests_data.append({
"prompt":
prompt,
"params":
params[i] if isinstance(params, list) else params,
"prompt_token_ids":
token_ids,
"lora_request":
lora_request,
"multi_modal_data":
multi_modal_item,
})
return requests_data
for i, request_inputs in enumerate(inputs):
self._add_request(
request_inputs,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request,
)
def _add_request(
self,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
inputs,
params,
prompt_token_ids,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
lora_request=lora_request)
def _run_engine(
self, use_tqdm: bool
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
@ -355,5 +556,4 @@ class LLM:
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
return sorted(outputs, key=lambda x: int(x.request_id))

View File

@ -176,9 +176,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e:
return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt_text, sampling_params,
request_id, prompt_ids,
lora_request)
result_generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
request_id,
lora_request,
)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(

View File

@ -119,12 +119,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats
generators.append(
self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids,
lora_request=lora_request))
generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
f"{request_id}-{i}",
lora_request=lora_request,
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
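
For context, the async engine now receives the same dict-shaped inputs the servers build above; a minimal sketch, with the model name and token IDs as illustrative placeholders (in the server the token IDs come from tokenizing the prompt text).

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model
    results = engine.generate(
        {"prompt": "Hello, my name is",
         "prompt_token_ids": [2, 100, 200]},  # placeholder IDs
        SamplingParams(max_tokens=16),
        "request-0",
    )
    async for output in results:
        final = output
    print(final.outputs[0].text)

asyncio.run(main())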

View File

@ -1,5 +1,5 @@
import time
from typing import AsyncIterator, List, Tuple
from typing import AsyncIterator, List, Optional, Tuple
from fastapi import Request
@ -100,11 +100,16 @@ class OpenAIServingEmbedding(OpenAIServing):
prompt_ids, prompt_text = prompt_formats
generators.append(
self.engine.generate(prompt_text,
pooling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids))
generator = self.engine.encode(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
pooling_params,
f"{request_id}-{i}",
)
generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -113,16 +118,21 @@ class OpenAIServingEmbedding(OpenAIServing):
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)
# Non-streaming response
final_res_batch: EmbeddingRequestOutput = [None] * len(prompts)
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
return response

View File

@ -143,7 +143,8 @@ class OpenAIServing:
return json_str
async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
@ -155,7 +156,8 @@ class OpenAIServing:
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return None

vllm/inputs.py (new file, 130 lines)
View File

@ -0,0 +1,130 @@
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)
from typing_extensions import NotRequired
if TYPE_CHECKING:
from vllm.sequence import MultiModalData
class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]
class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]
# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...
@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...
def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")
if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
class TextPrompt(TypedDict):
"""Schema for a text prompt."""
prompt: str
"""The input text to be tokenized before passing to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt."""
prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalData"]
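
A brief sketch of how these schemas are meant to be used at the public boundary versus internally; all values are placeholders. The TypedDicts behave like plain dicts at runtime.

from vllm.inputs import LLMInputs, TextPrompt, TokensPrompt

# Public API (PromptStrictInputs): any of these three forms.
as_str = "What is the capital of France?"
as_text = TextPrompt(prompt="What is the capital of France?")
as_tokens = TokensPrompt(prompt_token_ids=[101, 2054, 2003])  # placeholder IDs

# Internal form after processing: token IDs are always present.
processed = LLMInputs(prompt_token_ids=[101, 2054, 2003],
                      prompt="What is the capital of France?",
                      multi_modal_data=None)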

View File

@ -1,4 +1,5 @@
import time
from dataclasses import dataclass
from typing import List, Optional, Union
from vllm.lora.request import LoRARequest
@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)
@dataclass
class CompletionOutput:
"""The output data of one completion output of a request.
@ -24,25 +26,14 @@ class CompletionOutput:
lora_request: The LoRA request that was used to generate the output.
"""
def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprob: float,
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
stop_reason: Union[int, str, None] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
index: int
text: str
token_ids: List[int]
cumulative_logprob: float
logprobs: Optional[SampleLogprobs]
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
lora_request: Optional[LoRARequest] = None
def finished(self) -> bool:
return self.finish_reason is not None
@ -57,6 +48,7 @@ class CompletionOutput:
f"stop_reason={self.stop_reason})")
@dataclass
class EmbeddingOutput:
"""The output data of one completion output of a request.
@ -65,15 +57,11 @@ class EmbeddingOutput:
length of vector depends on the model as listed in the embedding guide.
"""
def __init__(
self,
embedding: List[float],
) -> None:
self.embedding = embedding
embedding: List[float]
def __repr__(self) -> str:
return (f"EmbeddingOutput("
f"embedding={len(self.embedding)}")
f"embedding={len(self.embedding)})")
class RequestOutput:
@ -93,7 +81,7 @@ class RequestOutput:
def __init__(
self,
request_id: str,
prompt: str,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
@ -183,7 +171,7 @@ class EmbeddingRequestOutput:
finished (bool): A flag indicating whether the embedding is completed.
"""
def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
def __init__(self, request_id: str, outputs: "EmbeddingOutput",
prompt_token_ids: List[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids
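
With the @dataclass conversion, field-based construction comes for free while the custom __repr__ methods are kept; a quick sketch with placeholder values.

from vllm.outputs import CompletionOutput, EmbeddingOutput

c = CompletionOutput(index=0, text=" Paris", token_ids=[6342],
                     cumulative_logprob=-0.7, logprobs=None)
print(c.finished())                           # False until finish_reason is set
print(EmbeddingOutput(embedding=[0.1, 0.2]))  # EmbeddingOutput(embedding=2)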

View File

@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@ -210,8 +211,7 @@ class Sequence:
Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
inputs: The inputs of the sequence.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
lora_request: LoRA request.
@ -220,25 +220,24 @@ class Sequence:
def __init__(
self,
seq_id: int,
prompt: str,
prompt_token_ids: List[int],
inputs: LLMInputs,
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.inputs = inputs
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.data: SequenceData = SequenceData(prompt_token_ids)
self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(prompt_token_ids)
self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None
@ -248,6 +247,18 @@ class Sequence:
# Input + output tokens
self.tokens: Optional[List[str]] = None
@property
def prompt(self) -> Optional[str]:
return self.inputs["prompt"]
@property
def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"]
@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs["multi_modal_data"]
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
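
Since LLMInputs is a TypedDict, a plain dict with the three keys is enough to build a Sequence in isolation; the token IDs and block size below are placeholders.

from vllm.sequence import Sequence

seq = Sequence(
    seq_id=0,
    inputs={"prompt": "Hello",
            "prompt_token_ids": [2, 100, 200],  # placeholder IDs
            "multi_modal_data": None},
    block_size=16,
)
print(seq.prompt, seq.prompt_token_ids, seq.multi_modal_data)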
@ -415,7 +426,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model.
pooling_params: The pooling parameters used to generate the pooling
@ -429,7 +439,6 @@ class SequenceGroup:
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
) -> None:
@ -444,12 +453,11 @@ class SequenceGroup:
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
self.embeddings = embeddings
self.pooling_params = pooling_params
@property
def prompt(self) -> str:
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt
@ -458,7 +466,13 @@ class SequenceGroup:
def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
return next(iter(self.seqs_dict.values())).prompt_token_ids
@property
def multi_modal_data(self) -> Optional[MultiModalData]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data
@property
def lora_int_id(self) -> int:

View File

@ -11,7 +11,7 @@ import threading
import uuid
import warnings
from collections import defaultdict
from functools import lru_cache, partial
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None:
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
def identity(value: T) -> T:
return value
F = TypeVar('F', bound=Callable[..., Any])
def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None) -> Callable[[F], F]:
deprecated_kws = set(kws)
if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)
def wrapper(fn: F) -> F:
@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_kwargs = kwargs.keys() & deprecated_kws
if deprecated_kwargs:
msg = (
f"The keyword arguments {deprecated_kwargs} are "
"deprecated and will be removed in a future update.")
if additional_message is not None:
msg += f" {additional_message}"
warnings.warn(
DeprecationWarning(msg),
stacklevel=3, # The inner function takes up one level
)
return fn(*args, **kwargs)
return inner # type: ignore
return wrapper
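
A usage sketch of the decorator on a hypothetical function, showing the warning path when a deprecated keyword is passed.

import warnings

from vllm.utils import deprecate_kwargs

@deprecate_kwargs("legacy_arg",
                  additional_message="Please use 'new_arg' instead.")
def do_work(new_arg=None, legacy_arg=None):  # hypothetical function
    return new_arg if new_arg is not None else legacy_arg

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    do_work(legacy_arg=1)  # emits DeprecationWarning
assert any(issubclass(w.category, DeprecationWarning) for w in caught)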