[Core] Consolidate prompt arguments to LLM engines (#4328)
Co-authored-by: Roger Wang <ywang@roblox.com>
Parent: 290f4ada2b
Commit: 5ae5ed1e60
@ -63,9 +63,9 @@ steps:
mirror_hardwares: [amd]

commands:
# these tests have to be separated, because each one will allocate all posible GPU memory
- pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py
- pytest -v -s entrypoints/test_server_oot_registration.py
- pytest -v -s test_inputs.py
- pytest -v -s entrypoints -m llm
- pytest -v -s entrypoints -m openai

- label: Examples Test
working_dir: "/vllm-workspace/examples"

@ -110,6 +110,9 @@ steps:
mirror_hardwares: [amd]
command: pytest -v -s test_logits_processor.py

- label: Utils Test
command: pytest -v -s test_utils.py

- label: Worker Test
mirror_hardwares: [amd]
command: pytest -v -s worker
@ -3,13 +3,14 @@ import argparse
import json
import time
from pathlib import Path
from typing import Optional
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

@ -48,7 +49,9 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()
dummy_inputs: List[PromptStrictInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]

def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:

@ -59,13 +62,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
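The hunk above replaces the legacy `prompt_token_ids=` keyword with a list of `PromptStrictInputs` dicts passed positionally. A minimal sketch of the new call pattern, assuming a local vLLM build with this change; the model name and shapes are placeholders, not part of the diff:

```python
# Illustrative sketch of the new benchmark input format shown above.
from typing import List

import numpy as np

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs

llm = LLM(model="facebook/opt-125m")          # placeholder model
sampling_params = SamplingParams(max_tokens=16)

token_ids = np.random.randint(10000, size=(8, 32)).tolist()
dummy_inputs: List[PromptStrictInputs] = [
    {"prompt_token_ids": batch} for batch in token_ids
]

# Old (deprecated): llm.generate(prompt_token_ids=token_ids, ...)
outputs = llm.generate(dummy_inputs, sampling_params=sampling_params)
```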
@ -1,5 +1,5 @@
LLM Class
==========
=========

.. autoclass:: vllm.LLM
:members:

docs/source/dev/offline_inference/llm_inputs.rst (new file, 14 lines)
@ -0,0 +1,14 @@
LLM Inputs
==========

.. autodata:: vllm.inputs.PromptStrictInputs

.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
:members:
:member-order: bysource

.. autoclass:: vllm.inputs.TokensPrompt
:show-inheritance:
:members:
:member-order: bysource
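The new docs page documents `PromptStrictInputs` as the union of a plain string, a `TextPrompt`, and a `TokensPrompt`. A rough sketch of what those types look like, inferred only from how they are used elsewhere in this diff; the authoritative definitions live in `vllm/inputs.py` and may differ in required fields:

```python
# Rough sketch inferred from usage in this diff, not the actual vllm/inputs.py.
from typing import List, Optional, TypedDict, Union


class TextPrompt(TypedDict, total=False):
    prompt: str                          # raw text to be tokenized
    multi_modal_data: Optional[object]   # e.g. MultiModalData for vision models


class TokensPrompt(TypedDict, total=False):
    prompt_token_ids: List[int]          # pre-tokenized prompt
    multi_modal_data: Optional[object]


# A bare string is also accepted and treated like a TextPrompt.
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]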
docs/source/dev/offline_inference/offline_index.rst (new file, 8 lines)
@ -0,0 +1,8 @@
Offline Inference
=================================

.. toctree::
:maxdepth: 1

llm
llm_inputs

@ -68,13 +68,6 @@ Documentation
getting_started/quickstart
getting_started/examples/examples_index

.. toctree::
:maxdepth: 1
:caption: Offline Inference

offline_inference/llm
offline_inference/sampling_params

.. toctree::
:maxdepth: 1
:caption: Serving

@ -108,7 +101,9 @@ Documentation
.. toctree::
:maxdepth: 2
:caption: Developer Documentation

dev/sampling_params
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/dockerfile/dockerfile

@ -48,7 +48,7 @@ completion = client.chat.completions.create(
```

### Extra Parameters for Chat API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python

@ -65,7 +65,7 @@ The following extra parameters are supported:
```

### Extra Parameters for Completions API
The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported.
The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported.

```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py
:language: python
@ -23,11 +23,15 @@ def run_llava_pixel_values():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_pixel_values.pt")
image = torch.load("images/stop_sign_pixel_values.pt")

outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)

@ -46,11 +50,14 @@ def run_llava_image_features():
"\nUSER: What is the content of this image?\nASSISTANT:")

# This should be provided by another online or offline component.
images = torch.load("images/stop_sign_image_features.pt")
image = torch.load("images/stop_sign_image_features.pt")

outputs = llm.generate(prompt,
multi_modal_data=MultiModalData(
type=MultiModalData.Type.IMAGE, data=images))
outputs = llm.generate({
"prompt":
prompt,
"multi_modal_data":
MultiModalData(type=MultiModalData.Type.IMAGE, data=image),
})
for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)
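The LLaVA example above stops passing `multi_modal_data=` as a separate keyword and instead bundles the image into the prompt dict. A condensed sketch of the migrated call, factored as a helper; `llm` and `prompt` are assumed to be built exactly as in `run_llava_pixel_values()` in this example script, and the `.pt` path is the script's own placeholder:

```python
# Sketch of the migrated multi-modal call; engine/model setup is elided.
import torch

from vllm import LLM
from vllm.sequence import MultiModalData


def generate_with_image(llm: LLM, prompt: str,
                        pt_path: str = "images/stop_sign_pixel_values.pt"):
    image = torch.load(pt_path)
    # Old API: llm.generate(prompt, multi_modal_data=MultiModalData(...))
    # New API: the image travels inside the prompt dict itself.
    return llm.generate({
        "prompt": prompt,
        "multi_modal_data": MultiModalData(type=MultiModalData.Type.IMAGE,
                                           data=image),
    })
```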
@ -65,3 +65,10 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort]
use_parentheses = true
skip_gitignore = true

[tool.pytest.ini_options]
markers = [
"skip_global_cleanup",
"llm: run tests for vLLM API only",
"openai: run tests for OpenAI API only",
]
@ -25,7 +25,7 @@ class MockEngine:
|
||||
return [RequestOutput(
|
||||
request_id=self.request_id)] if self.request_id else []
|
||||
|
||||
async def encode_request_async(self, *args, **kwargs):
|
||||
async def process_model_inputs_async(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def generate(self, request_id):
|
||||
|
@ -29,7 +29,7 @@ def server():
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
client = openai.AsyncOpenAI(
|
||||
base_url="http://localhost:8000/v1",
|
||||
|
@ -12,6 +12,7 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
|
||||
from vllm.distributed import destroy_model_parallel
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import MultiModalData
|
||||
|
||||
@ -402,12 +403,22 @@ class VllmRunner:
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
if images is not None:
|
||||
assert len(prompts) == images.shape[0]
|
||||
req_outputs = self.model.generate(
|
||||
prompts,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE,
|
||||
data=images)
|
||||
if images is not None else None)
|
||||
|
||||
prompt_inputs: List[PromptInputs] = []
|
||||
for i, prompt in enumerate(prompts):
|
||||
image = None if images is None else images[i:i + 1]
|
||||
mm_data = None if image is None else MultiModalData(
|
||||
type=MultiModalData.Type.IMAGE,
|
||||
data=image,
|
||||
)
|
||||
|
||||
prompt_inputs.append({
|
||||
"prompt": prompt,
|
||||
"multi_modal_data": mm_data,
|
||||
})
|
||||
|
||||
req_outputs = self.model.generate(prompt_inputs,
|
||||
sampling_params=sampling_params)
|
||||
outputs = []
|
||||
for req_output in req_outputs:
|
||||
prompt_str = req_output.prompt
|
||||
|
@ -133,8 +133,11 @@ def test_append_slot_cow():
# Allocate prompt to gpu block. There is one slot left in the block.
prompt = Sequence(seq_id=1,
prompt="one two three",
prompt_token_ids=[1, 2, 3],
inputs={
"prompt": "one two three",
"prompt_token_ids": [1, 2, 3],
"multi_modal_data": None
},
block_size=block_size)

# Fork the sequence, such that a COW will be required when we append a new

@ -304,7 +307,13 @@ def test_sliding_window_multi_seq():
assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks

parent = Sequence(1, "one two three", [0, 1, 2], block_size)
parent = Sequence(seq_id=1,
inputs={
"prompt": "one two three",
"prompt_token_ids": [0, 1, 2],
"multi_modal_data": None
},
block_size=block_size)
seq_group = SequenceGroup(request_id="1",
seqs=[parent],
arrival_time=time.time(),
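These tests show the other half of the consolidation: `Sequence` no longer takes separate `prompt=`/`prompt_token_ids=` keywords and instead receives one `inputs` dict. A minimal sketch mirroring the test code (values are arbitrary; the `eos_token_id`/`lora_request` arguments that appear later in this diff are omitted):

```python
# Mirrors the Sequence construction used in the tests above.
from vllm.sequence import Sequence

seq = Sequence(
    seq_id=1,
    inputs={
        "prompt": "one two three",        # original text, if any
        "prompt_token_ids": [1, 2, 3],    # tokenized prompt
        "multi_modal_data": None,         # no image/audio payload here
    },
    block_size=16,
)
```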
@ -21,7 +21,13 @@ def create_dummy_prompt(
|
||||
# and prompt "0 ... block_size".
|
||||
prompt_tokens = list(range(prompt_length))
|
||||
prompt_str = " ".join([str(t) for t in prompt_tokens])
|
||||
prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size)
|
||||
prompt = Sequence(int(request_id),
|
||||
inputs={
|
||||
"prompt": prompt_str,
|
||||
"prompt_token_ids": prompt_tokens,
|
||||
"multi_modal_data": None,
|
||||
},
|
||||
block_size=block_size)
|
||||
seq_group = SequenceGroup(request_id=request_id,
|
||||
seqs=[prompt],
|
||||
arrival_time=time.time(),
|
||||
@ -51,8 +57,11 @@ def create_seq_group(
|
||||
for seq_id_offset, output_len in enumerate(seq_output_lens):
|
||||
seq = Sequence(
|
||||
seq_id=seq_id_start + seq_id_offset,
|
||||
prompt="",
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
inputs={
|
||||
"prompt": "",
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
"multi_modal_data": None,
|
||||
},
|
||||
block_size=16,
|
||||
)
|
||||
|
||||
|
@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str):
|
||||
with pytest.raises(ValueError) as err:
|
||||
llm.generate("abc", sampling_params)
|
||||
assert "prompts must be None if" in str(err.value)
|
||||
outputs = llm.generate(prompt_token_ids=[[1, 2, 3]],
|
||||
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
|
||||
sampling_params=sampling_params)
|
||||
assert len(outputs) > 0
|
||||
completions = outputs[0].outputs
|
||||
|
@ -1,11 +1,15 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
|
||||
|
||||
MODEL_NAME = "openai-community/gpt2"
|
||||
CHAT_TEMPLATE = "Dummy chat template for testing {}"
|
||||
|
||||
pytestmark = pytest.mark.openai
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
|
@ -52,6 +52,8 @@ TEST_SCHEMA = {
|
||||
TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"
|
||||
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
|
||||
|
||||
pytestmark = pytest.mark.openai
|
||||
|
||||
|
||||
def test_guided_logits_processors():
|
||||
"""Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor."""
|
||||
|
144
tests/entrypoints/test_llm_encode.py
Normal file
144
tests/entrypoints/test_llm_encode.py
Normal file
@ -0,0 +1,144 @@
|
||||
import weakref
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
|
||||
|
||||
from ..conftest import cleanup
|
||||
|
||||
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
|
||||
|
||||
PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
TOKEN_IDS = [
|
||||
# Using ID={0, 1, 2, 3} results in NaN values,
|
||||
# so we add this offset of 1000
|
||||
[1000],
|
||||
[1000, 1001],
|
||||
[1000, 1002, 1001],
|
||||
[1000, 1003, 1001, 1002],
|
||||
]
|
||||
|
||||
pytestmark = pytest.mark.llm
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
# pytest caches the fixture so we use weakref.proxy to
|
||||
# enable garbage collection
|
||||
llm = LLM(model=MODEL_NAME,
|
||||
max_num_batched_tokens=32768,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.75,
|
||||
enforce_eager=True)
|
||||
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
|
||||
cleanup()
|
||||
|
||||
|
||||
def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
|
||||
o2: List[EmbeddingRequestOutput]):
|
||||
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt', PROMPTS)
|
||||
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode(prompt, pooling_params=pooling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
|
||||
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
prompt_token_ids):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
|
||||
v1_output = llm.encode(prompt_token_ids=prompt_token_ids,
|
||||
pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode({"prompt_token_ids": prompt_token_ids},
|
||||
pooling_params=pooling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.encode(
|
||||
[{
|
||||
"prompt": p
|
||||
} for p in PROMPTS],
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
|
||||
pooling_params = PoolingParams()
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
|
||||
v1_output = llm.encode(prompt_token_ids=TOKEN_IDS,
|
||||
pooling_params=pooling_params)
|
||||
|
||||
v2_output = llm.encode(
|
||||
[{
|
||||
"prompt_token_ids": p
|
||||
} for p in TOKEN_IDS],
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_multiple_pooling_params(llm: LLM):
|
||||
pooling_params = [
|
||||
PoolingParams(),
|
||||
PoolingParams(),
|
||||
PoolingParams(),
|
||||
PoolingParams(),
|
||||
]
|
||||
|
||||
# Multiple PoolingParams should be matched with each prompt
|
||||
outputs = llm.encode(PROMPTS, pooling_params=pooling_params)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
# Exception raised, if the size of params does not match the size of prompts
|
||||
with pytest.raises(ValueError):
|
||||
outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3])
|
||||
|
||||
# Single PoolingParams should be applied to every prompt
|
||||
single_pooling_params = PoolingParams()
|
||||
outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
# pooling_params is None, default params should be applied
|
||||
outputs = llm.encode(PROMPTS, pooling_params=None)
|
||||
assert len(PROMPTS) == len(outputs)
|
@ -1,21 +1,124 @@
|
||||
import weakref
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm import LLM, RequestOutput, SamplingParams
|
||||
|
||||
from ..conftest import cleanup
|
||||
|
||||
MODEL_NAME = "facebook/opt-125m"
|
||||
|
||||
PROMPTS = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
|
||||
TOKEN_IDS = [
|
||||
[0],
|
||||
[0, 1],
|
||||
[0, 2, 1],
|
||||
[0, 3, 1, 2],
|
||||
]
|
||||
|
||||
pytestmark = pytest.mark.llm
|
||||
|
||||
|
||||
def test_multiple_sampling_params():
|
||||
|
||||
llm = LLM(model="facebook/opt-125m",
|
||||
@pytest.fixture(scope="module")
|
||||
def llm():
|
||||
# pytest caches the fixture so we use weakref.proxy to
|
||||
# enable garbage collection
|
||||
llm = LLM(model=MODEL_NAME,
|
||||
max_num_batched_tokens=4096,
|
||||
tensor_parallel_size=1)
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.10,
|
||||
enforce_eager=True)
|
||||
|
||||
prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
with llm.deprecate_legacy_api():
|
||||
yield weakref.proxy(llm)
|
||||
|
||||
del llm
|
||||
|
||||
cleanup()
|
||||
|
||||
|
||||
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
|
||||
assert [o.outputs for o in o1] == [o.outputs for o in o2]
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt', PROMPTS)
|
||||
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
|
||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.generate(prompts=prompt,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
v2_output = llm.generate(prompt, sampling_params=sampling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.generate({"prompt": prompt},
|
||||
sampling_params=sampling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
|
||||
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
|
||||
prompt_token_ids):
|
||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
|
||||
v1_output = llm.generate(prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
v2_output = llm.generate({"prompt_token_ids": prompt_token_ids},
|
||||
sampling_params=sampling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
|
||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompts'"):
|
||||
v1_output = llm.generate(prompts=PROMPTS,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
v2_output = llm.generate(
|
||||
[{
|
||||
"prompt": p
|
||||
} for p in PROMPTS],
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
|
||||
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
|
||||
|
||||
with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"):
|
||||
v1_output = llm.generate(prompt_token_ids=TOKEN_IDS,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
v2_output = llm.generate(
|
||||
[{
|
||||
"prompt_token_ids": p
|
||||
} for p in TOKEN_IDS],
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
assert_outputs_equal(v1_output, v2_output)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_multiple_sampling_params(llm: LLM):
|
||||
sampling_params = [
|
||||
SamplingParams(temperature=0.01, top_p=0.95),
|
||||
SamplingParams(temperature=0.3, top_p=0.95),
|
||||
@ -24,18 +127,18 @@ def test_multiple_sampling_params():
|
||||
]
|
||||
|
||||
# Multiple SamplingParams should be matched with each prompt
|
||||
outputs = llm.generate(prompts, sampling_params=sampling_params)
|
||||
assert len(prompts) == len(outputs)
|
||||
outputs = llm.generate(PROMPTS, sampling_params=sampling_params)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
# Exception raised, if the size of params does not match the size of prompts
|
||||
with pytest.raises(ValueError):
|
||||
outputs = llm.generate(prompts, sampling_params=sampling_params[:3])
|
||||
outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3])
|
||||
|
||||
# Single SamplingParams should be applied to every prompt
|
||||
single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95)
|
||||
outputs = llm.generate(prompts, sampling_params=single_sampling_params)
|
||||
assert len(prompts) == len(outputs)
|
||||
outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
||||
# sampling_params is None, default params should be applied
|
||||
outputs = llm.generate(prompts, sampling_params=None)
|
||||
assert len(prompts) == len(outputs)
|
||||
outputs = llm.generate(PROMPTS, sampling_params=None)
|
||||
assert len(PROMPTS) == len(outputs)
|
||||
|
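The rewritten `tests/entrypoints/test_llm_generate.py` above exercises both the legacy keyword API, which now emits `DeprecationWarning` under `llm.deprecate_legacy_api()`, and the consolidated positional API. As a migration summary, using the same small model the tests use:

```python
# Migration summary for LLM.generate(), based on the tests above.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.0, top_p=1.0)

# Legacy keyword API (still accepted, but deprecated):
out_v1 = llm.generate(prompts="Hello, my name is", sampling_params=params)
out_v1b = llm.generate(prompt_token_ids=[[0, 1, 2]], sampling_params=params)

# Consolidated API: a string, a dict, or a list of either, passed positionally.
out_v2 = llm.generate("Hello, my name is", sampling_params=params)
out_v2b = llm.generate({"prompt_token_ids": [0, 1, 2]}, sampling_params=params)
out_v2c = llm.generate([{"prompt": p} for p in ["Hi", "Bye"]],
                       sampling_params=params)
```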
@ -71,7 +71,7 @@ TEST_CHOICE = [
|
||||
"Swift", "Kotlin"
|
||||
]
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
pytestmark = pytest.mark.openai
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@ -91,6 +91,8 @@ def server(zephyr_lora_files):
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
"--gpu-memory-utilization",
|
||||
"0.75",
|
||||
# lora config below
|
||||
"--enable-lora",
|
||||
"--lora-modules",
|
||||
@ -118,9 +120,11 @@ def embedding_server(zephyr_lora_files):
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--enforce-eager",
|
||||
"--gpu-memory-utilization",
|
||||
"0.75",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"--enforce-eager",
|
||||
])
|
||||
ray.get(server_runner.ready.remote())
|
||||
yield server_runner
|
||||
@ -136,6 +140,7 @@ def client():
|
||||
yield client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_models(server, client: openai.AsyncOpenAI):
|
||||
models = await client.models.list()
|
||||
models = models.data
|
||||
@ -147,6 +152,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI):
|
||||
assert lora_models[1].id == "zephyr-lora2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
@ -178,6 +184,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
|
||||
completion.choices[0].text) >= 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
@ -199,6 +206,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
|
||||
assert choice.logprobs.top_logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
@ -243,6 +251,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
@ -298,6 +307,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
|
||||
assert message.content is not None and len(message.content) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
@ -335,6 +345,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI,
|
||||
assert "".join(chunks) == single_output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
@ -385,6 +396,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI,
|
||||
assert "".join(chunks) == output
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
@ -438,6 +450,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI,
|
||||
assert texts[0] == texts[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logits_bias(server, client: openai.AsyncOpenAI):
|
||||
prompt = "Hello, my name is"
|
||||
max_tokens = 5
|
||||
@ -485,6 +498,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
|
||||
assert first_response != completion.choices[0].text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
|
||||
@ -507,6 +521,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
|
||||
jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
|
||||
@ -553,6 +568,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
|
||||
assert json1["age"] != json2["age"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
|
||||
@ -573,6 +589,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
|
||||
assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
|
||||
@ -610,6 +627,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
|
||||
assert ip1 != ip2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
|
||||
@ -629,6 +647,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
|
||||
assert completion.choices[i].text in TEST_CHOICE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
|
||||
@ -667,6 +686,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
|
||||
assert choice1 != choice2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
|
||||
@ -702,6 +722,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
|
||||
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
["outlines", "lm-format-enforcer"])
|
||||
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
|
||||
@ -732,6 +753,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
|
||||
for token, logprob in token_dict.items())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
|
||||
for _ in range(2):
|
||||
resp = await client.chat.completions.create(
|
||||
@ -749,6 +771,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
|
||||
assert loaded == {"result": 2}, loaded
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_fields(server, client: openai.AsyncOpenAI):
|
||||
with pytest.raises(BadRequestError) as exc_info:
|
||||
await client.chat.completions.create(
|
||||
@ -764,6 +787,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
|
||||
assert "extra_forbidden" in exc_info.value.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_message_content(server, client: openai.AsyncOpenAI):
|
||||
resp = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
@ -783,6 +807,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI):
|
||||
assert content == "2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_role(server, client: openai.AsyncOpenAI):
|
||||
# Not sure how the model handles custom roles so we just check that
|
||||
# both string and complex message content are handled in the same way
|
||||
@ -813,6 +838,7 @@ async def test_custom_role(server, client: openai.AsyncOpenAI):
|
||||
assert content1 == content2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
|
||||
simple_sql_grammar = """
|
||||
start: select_statement
|
||||
@ -847,6 +873,7 @@ number: "1" | "2"
|
||||
assert content.strip() == ground_truth
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
@ -878,6 +905,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
|
||||
assert len(logprobs.tokens) > 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_seed(server, client: openai.AsyncOpenAI):
|
||||
for seed in [
|
||||
torch.iinfo(torch.long).min - 1,
|
||||
@ -897,6 +925,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI):
|
||||
or "less_than_equal" in exc_info.value.message)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
@ -935,6 +964,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
|
||||
assert embeddings.usage.total_tokens == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[EMBEDDING_MODEL_NAME],
|
||||
|
@ -1,7 +1,7 @@
|
||||
import multiprocessing
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from openai import OpenAI, OpenAIError
|
||||
|
||||
@ -10,6 +10,8 @@ from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
pytestmark = pytest.mark.openai
|
||||
|
||||
|
||||
class MyOPTForCausalLM(OPTForCausalLM):
|
||||
|
||||
@ -26,15 +28,16 @@ def server_function(port):
|
||||
# register our dummy model
|
||||
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
|
||||
sys.argv = ["placeholder.py"] + \
|
||||
("--model facebook/opt-125m --dtype"
|
||||
f" float32 --api-key token-abc123 --port {port}").split()
|
||||
("--model facebook/opt-125m --gpu-memory-utilization 0.10 "
|
||||
f"--dtype float32 --api-key token-abc123 --port {port}").split()
|
||||
import runpy
|
||||
runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__')
|
||||
|
||||
|
||||
def test_oot_registration_for_api_server():
|
||||
port = get_open_port()
|
||||
server = multiprocessing.Process(target=server_function, args=(port, ))
|
||||
ctx = torch.multiprocessing.get_context()
|
||||
server = ctx.Process(target=server_function, args=(port, ))
|
||||
server.start()
|
||||
client = OpenAI(
|
||||
base_url=f"http://localhost:{port}/v1",
|
||||
|
@ -86,20 +86,18 @@ def generate(
|
||||
|
||||
|
||||
def batched_generate(
|
||||
llm,
|
||||
llm: vllm.LLM,
|
||||
inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]],
|
||||
):
|
||||
for input in inputs:
|
||||
prompt, sampling_param, lora_req = input
|
||||
requests_data = llm._validate_and_prepare_requests(
|
||||
# Add requests to the engine and run the engine
|
||||
llm._validate_and_add_requests(
|
||||
prompt,
|
||||
sampling_param,
|
||||
lora_request=lora_req,
|
||||
)
|
||||
|
||||
# Add requests to the engine and run the engine
|
||||
for request_data in requests_data:
|
||||
llm._add_request(**request_data)
|
||||
outputs = llm._run_engine(use_tqdm=True)
|
||||
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
|
||||
|
||||
|
@ -35,28 +35,25 @@ def test_logits_processor_force_generate(
|
||||
|
||||
# test logits_processors when prompt_logprobs is not None
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[0],
|
||||
example_prompts[0],
|
||||
params=params_with_logprobs,
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
# test prompt_logprobs is not None
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[1],
|
||||
example_prompts[1],
|
||||
params=SamplingParams(
|
||||
prompt_logprobs=3,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
# test grouped requests
|
||||
vllm_model.model._add_request(
|
||||
prompt=example_prompts[2],
|
||||
example_prompts[2],
|
||||
params=SamplingParams(max_tokens=max_tokens),
|
||||
prompt_token_ids=None,
|
||||
)
|
||||
|
||||
outputs = vllm_model.model._run_engine(False)
|
||||
outputs = vllm_model.model._run_engine(use_tqdm=False)
|
||||
|
||||
assert outputs[0].outputs[0].text == enforced_answers * repeat_times
|
||||
|
@ -57,11 +57,7 @@ def test_random_sample_with_seed(
|
||||
sampling_params_seed_1,
|
||||
sampling_params_seed_2,
|
||||
):
|
||||
llm._add_request(
|
||||
prompt=prompt,
|
||||
prompt_token_ids=None,
|
||||
params=params,
|
||||
)
|
||||
llm._add_request(prompt, params=params)
|
||||
|
||||
results = llm._run_engine(use_tqdm=False)
|
||||
all_outputs = [[out.token_ids for out in output.outputs]
|
||||
|
@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
|
||||
for prompt in prompts:
|
||||
hashes[-1].append([])
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
||||
tokenizer.tokenizer.eos_token_id, lora_request)
|
||||
seq = Sequence(seq_id,
|
||||
inputs={
|
||||
"prompt": prompt,
|
||||
"prompt_token_ids": prompt_token_ids,
|
||||
"multi_modal_data": None,
|
||||
},
|
||||
block_size=block_size,
|
||||
eos_token_id=tokenizer.tokenizer.eos_token_id,
|
||||
lora_request=lora_request)
|
||||
|
||||
num_blocks = len(prompt_token_ids) // block_size
|
||||
for idx in range(num_blocks):
|
||||
|
tests/test_inputs.py (new file, 53 lines)
@ -0,0 +1,53 @@
from typing import List

import pytest

from vllm.inputs import parse_and_batch_prompt

STRING_INPUTS = [
    '',
    'foo',
    'foo bar',
    'foo baz bar',
    'foo bar qux baz',
]

TOKEN_INPUTS = [
    [-1],
    [1],
    [1, 2],
    [1, 3, 4],
    [1, 2, 4, 3],
]

INPUTS_SLICES = [
    slice(None, None, -1),
    slice(None, None, 2),
    slice(None, None, -2),
]


def test_parse_single_batch_empty():
    with pytest.raises(ValueError, match="at least one prompt"):
        parse_and_batch_prompt([])

    with pytest.raises(ValueError, match="at least one prompt"):
        parse_and_batch_prompt([[]])


@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
    assert parse_and_batch_prompt(string_input) \
        == parse_and_batch_prompt([string_input])


@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]):
    assert parse_and_batch_prompt(token_input) \
        == parse_and_batch_prompt([token_input])


@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
    assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \
        == parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
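`parse_and_batch_prompt` itself is not shown in this diff; the new tests only pin down its observable behaviour (a single string or token list is treated as a batch of one, empty input raises). A behavioural sketch consistent with those tests — not the actual vLLM implementation, which returns vLLM-specific parsed types rather than plain dicts:

```python
# Behavioural sketch only; the real function lives in vllm/inputs.py.
from typing import List, Sequence, Union


def parse_and_batch_prompt_sketch(
    prompt: Union[str, List[int], Sequence[str], Sequence[List[int]]],
) -> List[dict]:
    # Promote a bare string or a bare token list to a batch of one.
    if isinstance(prompt, str):
        prompt = [prompt]
    elif isinstance(prompt, list) and prompt and isinstance(prompt[0], int):
        prompt = [prompt]

    if len(prompt) == 0:
        raise ValueError("please provide at least one prompt")

    parsed: List[dict] = []
    for item in prompt:
        if isinstance(item, str):
            parsed.append({"content": item, "is_tokens": False})
        elif isinstance(item, list) and item:
            parsed.append({"content": item, "is_tokens": True})
        else:
            # e.g. an empty inner token list
            raise ValueError("please provide at least one prompt")
    return parsed
```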
tests/test_utils.py (new file, 63 lines)
@ -0,0 +1,63 @@
import pytest

from vllm.utils import deprecate_kwargs

from .utils import error_on_warning


def test_deprecate_kwargs_always():

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning():
        dummy(new_arg=1)


def test_deprecate_kwargs_never():

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with error_on_warning():
        dummy(old_arg=1)

    with error_on_warning():
        dummy(new_arg=1)


def test_deprecate_kwargs_dynamic():
    is_deprecated = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning():
        dummy(new_arg=1)

    is_deprecated = False

    with error_on_warning():
        dummy(old_arg=1)

    with error_on_warning():
        dummy(new_arg=1)


def test_deprecate_kwargs_additional_message():

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="abcd"):
        dummy(old_arg=1)
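`deprecate_kwargs`, added to `vllm/utils.py` in this PR, is what drives the `DeprecationWarning`s asserted above. A simplified sketch of such a decorator, consistent with the tested behaviour (`is_deprecated` may be a bool or a callable; `additional_message` is appended to the warning) — not the exact vLLM code:

```python
# Simplified sketch of a deprecate_kwargs-style decorator.
import functools
import warnings
from typing import Callable, Optional, Union


def deprecate_kwargs_sketch(*kw_names: str,
                            is_deprecated: Union[bool, Callable[[], bool]] = True,
                            additional_message: Optional[str] = None):
    deprecated = set(kw_names)

    def wrapper(fn):

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            active = is_deprecated() if callable(is_deprecated) else is_deprecated
            if active:
                used = deprecated.intersection(kwargs)
                if used:
                    msg = (f"The keyword arguments {used} are deprecated and "
                           "will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"
                    warnings.warn(DeprecationWarning(msg), stacklevel=3)
            return fn(*args, **kwargs)

        return inner

    return wrapper
```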
@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None):
prompt_token_ids = prompt_token_ids or [1]
return Sequence(
seq_id=0,
prompt="<s>",
prompt_token_ids=prompt_token_ids,
inputs={
"prompt": "<s>",
"prompt_token_ids": prompt_token_ids,
"multi_modal_data": None,
},
block_size=16,
)

@ -2,6 +2,8 @@ import os
import subprocess
import sys
import time
import warnings
from contextlib import contextmanager

import ray
import requests

@ -87,3 +89,15 @@ def multi_process_tensor_parallel(
ray.get(refs)
ray.shutdown()


@contextmanager
def error_on_warning():
"""
Within the scope of this context manager, tests will fail if any warning
is emitted.
"""
with warnings.catch_warnings():
warnings.simplefilter("error")

yield
@ -5,6 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)

@ -16,6 +17,9 @@ __version__ = "0.4.2"
__all__ = [
"LLM",
"ModelRegistry",
"PromptStrictInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",
"RequestOutput",
"CompletionOutput",
@ -12,12 +12,13 @@ from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
||||
from vllm.inputs import LLMInputs, PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest, MultiModalData, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -244,64 +245,69 @@ class _AsyncLLMEngine(LLMEngine):
|
||||
|
||||
return request_outputs
|
||||
|
||||
async def encode_request_async(
|
||||
async def process_model_inputs_async(
|
||||
self,
|
||||
request_id: str, # pylint: disable=unused-argument
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
request_id: str,
|
||||
inputs: PromptInputs,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
):
|
||||
if prompt_token_ids is None:
|
||||
assert prompt is not None
|
||||
prompt_token_ids = await self.tokenizer.encode_async(
|
||||
) -> LLMInputs:
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
|
||||
if "prompt_token_ids" not in inputs:
|
||||
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
prompt_token_ids = await tokenizer.encode_async(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt=inputs["prompt"],
|
||||
lora_request=lora_request)
|
||||
return prompt_token_ids
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=inputs.get("prompt"),
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
|
||||
async def add_request_async(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
inputs: PromptInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> None:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
prompt_token_ids = await self.encode_request_async(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request)
|
||||
|
||||
return self.add_request(request_id,
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data)
|
||||
processed_inputs = await self.process_model_inputs_async(
|
||||
request_id=request_id, inputs=inputs, lora_request=lora_request)
|
||||
|
||||
self._add_processed_request(
|
||||
request_id=request_id,
|
||||
processed_inputs=processed_inputs,
|
||||
params=params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
async def check_health_async(self) -> None:
|
||||
self.model_executor.check_health()
|
||||
|
||||
|
||||
class AsyncLLMEngine:
|
||||
"""An asynchronous wrapper for LLMEngine.
|
||||
"""An asynchronous wrapper for :class:`LLMEngine`.
|
||||
|
||||
This class is used to wrap the LLMEngine class to make it asynchronous. It
|
||||
uses asyncio to create a background loop that keeps processing incoming
|
||||
requests. The LLMEngine is kicked by the generate method when there
|
||||
are requests in the waiting queue. The generate method yields the outputs
|
||||
from the LLMEngine to the caller.
|
||||
This class is used to wrap the :class:`LLMEngine` class to make it
|
||||
asynchronous. It uses asyncio to create a background loop that keeps
|
||||
processing incoming requests. The :class:`LLMEngine` is kicked by the
|
||||
generate method when there are requests in the waiting queue. The generate
|
||||
method yields the outputs from the :class:`LLMEngine` to the caller.
|
||||
|
||||
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
|
||||
NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`.
|
||||
|
||||
Args:
|
||||
worker_use_ray: Whether to use Ray for model workers. Required for
|
||||
@ -315,8 +321,8 @@ class AsyncLLMEngine:
|
||||
being printed in log.
|
||||
start_engine_loop: If True, the background task to run the engine
|
||||
will be automatically started in the generate call.
|
||||
*args: Arguments for LLMEngine.
|
||||
*kwargs: Arguments for LLMEngine.
|
||||
*args: Arguments for :class:`LLMEngine`.
|
||||
**kwargs: Arguments for :class:`LLMEngine`.
|
||||
"""
|
||||
|
||||
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
||||
@ -526,22 +532,26 @@ class AsyncLLMEngine:
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
inputs: PromptInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
shortened_prompt = prompt
|
||||
shortened_token_ids = prompt_token_ids
|
||||
if self.max_log_len is not None:
|
||||
if isinstance(inputs, str):
|
||||
shortened_prompt = inputs
|
||||
shortened_token_ids = None
|
||||
else:
|
||||
shortened_prompt = inputs.get("prompt")
|
||||
shortened_token_ids = inputs.get("prompt_token_ids")
|
||||
|
||||
max_log_len = self.max_log_len
|
||||
if max_log_len is not None:
|
||||
if shortened_prompt is not None:
|
||||
shortened_prompt = shortened_prompt[:self.max_log_len]
|
||||
shortened_prompt = shortened_prompt[:max_log_len]
|
||||
if shortened_token_ids is not None:
|
||||
shortened_token_ids = shortened_token_ids[:self.
|
||||
max_log_len]
|
||||
shortened_token_ids = shortened_token_ids[:max_log_len]
|
||||
|
||||
logger.info(
|
||||
"Received request %s: prompt: %r, "
|
||||
"params: %s, prompt_token_ids: %s, "
|
||||
@ -562,39 +572,33 @@ class AsyncLLMEngine:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.engine_use_ray:
|
||||
prompt_token_ids = await (
|
||||
self.engine.encode_request_async.remote( # type: ignore
|
||||
processed_inputs = await self.engine.process_model_inputs_async \
|
||||
.remote( # type: ignore
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request))
|
||||
inputs=inputs,
|
||||
lora_request=lora_request)
|
||||
else:
|
||||
prompt_token_ids = await self.engine.encode_request_async(
|
||||
processed_inputs = await self.engine.process_model_inputs_async(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
inputs=inputs,
|
||||
lora_request=lora_request)
|
||||
|
||||
stream = self._request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
inputs=processed_inputs,
|
||||
params=params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
return stream
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
inputs: PromptInputs,
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None
|
||||
) -> AsyncIterator[RequestOutput]:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
@ -603,14 +607,12 @@ class AsyncLLMEngine:
|
||||
from the LLMEngine to the caller.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||
provided.
|
||||
inputs: The inputs to the LLM. See
|
||||
:class:`~vllm.inputs.PromptInputs`
|
||||
for more details about the format of each input.
|
||||
sampling_params: The sampling parameters of the request.
|
||||
request_id: The unique id of the request.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
multi_modal_data: Multi modal data per request.
|
||||
|
||||
Yields:
|
||||
The output `RequestOutput` objects from the LLMEngine
|
||||
@ -659,24 +661,20 @@ class AsyncLLMEngine:
|
||||
>>> # Process and return the final output
|
||||
>>> ...
|
||||
"""
|
||||
async for output in self.process_request(
|
||||
async for output in self._process_request(
|
||||
request_id,
|
||||
prompt,
|
||||
inputs,
|
||||
sampling_params,
|
||||
prompt_token_ids,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
lora_request=lora_request,
|
||||
):
|
||||
yield output
|
||||
yield LLMEngine.validate_output(output, RequestOutput)
|
||||
|
||||
async def encode(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
inputs: PromptInputs,
|
||||
pooling_params: PoolingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None
|
||||
) -> AsyncIterator[EmbeddingRequestOutput]:
|
||||
"""Generate outputs for a request from an embedding model.
|
||||
|
||||
@ -685,14 +683,12 @@ class AsyncLLMEngine:
|
||||
from the LLMEngine to the caller.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||
provided.
|
||||
inputs: The inputs to the LLM. See
|
||||
:class:`~vllm.inputs.PromptInputs`
|
||||
for more details about the format of each input.
|
||||
pooling_params: The pooling parameters of the request.
|
||||
request_id: The unique id of the request.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
multi_modal_data: Multi modal data per request.
|
||||
|
||||
Yields:
|
||||
The output `EmbeddingRequestOutput` objects from the LLMEngine
|
||||
@ -739,24 +735,21 @@ class AsyncLLMEngine:
|
||||
>>> # Process and return the final output
|
||||
>>> ...
|
||||
"""
|
||||
async for output in self.process_request(
|
||||
async for output in self._process_request(
|
||||
request_id,
|
||||
prompt,
|
||||
inputs,
|
||||
pooling_params,
|
||||
prompt_token_ids,
|
||||
lora_request,
|
||||
multi_modal_data,
|
||||
lora_request=lora_request,
|
||||
):
|
||||
yield output
|
||||
yield LLMEngine.validate_output(output, EmbeddingRequestOutput)
|
||||
|
||||
async def process_request(
|
||||
async def _process_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
inputs: PromptInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
*,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
multi_modal_data: Optional[MultiModalData] = None,
|
||||
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
|
||||
"""Common logic to process requests with SamplingParams or
|
||||
PoolingParams."""
|
||||
@ -764,12 +757,10 @@ class AsyncLLMEngine:
|
||||
|
||||
stream = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
inputs,
|
||||
params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
multi_modal_data=multi_modal_data,
|
||||
)
|
||||
|
||||
try:
|
||||
|
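On the online-serving side, `AsyncLLMEngine.generate()` and `encode()` now take a single `inputs: PromptInputs` argument instead of `prompt` plus optional `prompt_token_ids`/`multi_modal_data`. A hedged sketch of a caller under the new signature, assuming an engine built from `AsyncEngineArgs` with a placeholder model and request id:

```python
# Sketch of the post-change AsyncLLMEngine calling convention; error handling
# and engine shutdown are omitted for brevity.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    params = SamplingParams(max_tokens=16)

    # Old: engine.generate(prompt, params, request_id, prompt_token_ids=...)
    # New: the first argument is a PromptInputs (a string or a prompt dict).
    final = None
    async for output in engine.generate({"prompt": "Hello, my name is"},
                                        params, request_id="req-0"):
        final = output  # each yield is a cumulative RequestOutput
    print(final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
```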
@ -1,5 +1,8 @@
|
||||
import time
|
||||
from typing import Iterable, List, Optional, Type, Union
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Type, TypeVar, Union
|
||||
|
||||
from transformers import GenerationConfig, PreTrainedTokenizer
|
||||
|
||||
@ -18,6 +21,7 @@ from vllm.engine.output_processor.stop_checker import StopChecker
|
||||
from vllm.engine.output_processor.util import create_output_by_sequence_group
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import LLMInputs, PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
@ -25,8 +29,8 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest,
|
||||
MultiModalData, PoolerOutput, SamplerOutput,
|
||||
Sequence, SequenceGroup, SequenceGroupMetadata,
|
||||
PoolerOutput, SamplerOutput, Sequence,
|
||||
SequenceGroup, SequenceGroupMetadata,
|
||||
SequenceStatus)
|
||||
from vllm.transformers_utils.detokenizer import Detokenizer
|
||||
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
|
||||
@ -50,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig):
|
||||
return {}
|
||||
|
||||
|
||||
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
"""An LLM engine that receives requests and generates texts.
|
||||
|
||||
@ -60,11 +67,11 @@ class LLMEngine:
|
||||
iteration-level scheduling and efficient memory management to maximize the
|
||||
serving throughput.
|
||||
|
||||
The `LLM` class wraps this class for offline batched inference and the
|
||||
`AsyncLLMEngine` class wraps this class for online serving.
|
||||
The :class:`~vllm.LLM` class wraps this class for offline batched inference
|
||||
and the :class:`AsyncLLMEngine` class wraps this class for online serving.
|
||||
|
||||
NOTE: The config arguments are derived from the `EngineArgs` class. For the
|
||||
comprehensive list of arguments, see `EngineArgs`.
|
||||
NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs`
|
||||
class. For the comprehensive list of arguments, see :ref:`engine_args`.
|
||||
|
||||
Args:
|
||||
model_config: The configuration related to the LLM model.
|
||||
@ -81,9 +88,60 @@ class LLMEngine:
|
||||
executor_class: The model executor class for managing distributed
|
||||
execution.
|
||||
log_stats: Whether to log statistics.
|
||||
usage_context: Specified entry point, used for usage info collection
|
||||
usage_context: Specified entry point, used for usage info collection.
|
||||
"""
|
||||
|
||||
DO_VALIDATE_OUTPUT: ClassVar[bool] = False
|
||||
"""A flag to toggle whether to validate the type of request output."""
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def enable_output_validation(cls):
|
||||
cls.DO_VALIDATE_OUTPUT = True
|
||||
|
||||
yield
|
||||
|
||||
cls.DO_VALIDATE_OUTPUT = False
|
||||
|
||||
@classmethod
|
||||
def validate_output(
|
||||
cls,
|
||||
output: object,
|
||||
output_type: Type[_O],
|
||||
) -> _O:
|
||||
do_validate = cls.DO_VALIDATE_OUTPUT
|
||||
|
||||
if ((TYPE_CHECKING or do_validate)
|
||||
and not isinstance(output, output_type)):
|
||||
raise TypeError(f"Expected output of type {output_type}, "
|
||||
f"but found type {type(output)}")
|
||||
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
def validate_outputs(
|
||||
cls,
|
||||
outputs: GenericSequence[object],
|
||||
output_type: Type[_O],
|
||||
) -> List[_O]:
|
||||
do_validate = cls.DO_VALIDATE_OUTPUT
|
||||
|
||||
outputs_: List[_O]
|
||||
if TYPE_CHECKING or do_validate:
|
||||
outputs_ = []
|
||||
for output in outputs:
|
||||
if not isinstance(output, output_type):
|
||||
raise TypeError(f"Expected output of type {output_type}, "
|
||||
f"but found type {type(output)}")
|
||||
|
||||
outputs_.append(output)
|
||||
else:
|
||||
outputs_ = outputs
|
||||
|
||||
return outputs_
|
||||
|
||||
tokenizer: Optional[BaseTokenizerGroup]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
@ -151,12 +209,11 @@ class LLMEngine:
|
||||
self.log_stats = log_stats
|
||||
|
||||
if not self.model_config.skip_tokenizer_init:
|
||||
self.tokenizer: BaseTokenizerGroup
|
||||
self._init_tokenizer()
|
||||
self.tokenizer = self._init_tokenizer()
|
||||
self.detokenizer = Detokenizer(self.tokenizer)
|
||||
else:
|
||||
self.detokenizer = None
|
||||
self.tokenizer = None
|
||||
self.detokenizer = None
|
||||
|
||||
self.seq_counter = Counter()
|
||||
self.generation_config_fields = _load_generation_config_dict(
|
||||
@ -318,14 +375,26 @@ class LLMEngine:
|
||||
if model_executor := getattr(self, "model_executor", None):
|
||||
model_executor.shutdown()
|
||||
|
||||
MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
def get_tokenizer_group(
|
||||
self,
|
||||
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
|
||||
if self.tokenizer is None:
|
||||
raise ValueError(fail_msg)
|
||||
|
||||
return self.tokenizer
|
||||
|
||||
def get_tokenizer(self) -> "PreTrainedTokenizer":
|
||||
return self.tokenizer.get_lora_tokenizer(None)
|
||||
return self.get_tokenizer_group().get_lora_tokenizer(None)
|
||||
|
||||
def get_tokenizer_for_seq(self,
|
||||
sequence: Sequence) -> "PreTrainedTokenizer":
|
||||
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
|
||||
return self.get_tokenizer_group().get_lora_tokenizer(
|
||||
sequence.lora_request)
|
||||
|
||||
def _init_tokenizer(self, **tokenizer_init_kwargs):
|
||||
def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
|
||||
init_kwargs = dict(
|
||||
tokenizer_id=self.model_config.tokenizer,
|
||||
enable_lora=bool(self.lora_config),
|
||||
@ -335,8 +404,9 @@ class LLMEngine:
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
revision=self.model_config.tokenizer_revision)
|
||||
init_kwargs.update(tokenizer_init_kwargs)
|
||||
self.tokenizer = get_tokenizer_group(
|
||||
self.parallel_config.tokenizer_pool_config, **init_kwargs)
|
||||
|
||||
return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
|
||||
**init_kwargs)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
@ -346,29 +416,85 @@ class LLMEngine:
|
||||
self.lora_config.verify_with_scheduler_config(
|
||||
self.scheduler_config)
|
||||
|
||||
def encode_request(
|
||||
def _get_eos_token_id(
|
||||
self, lora_request: Optional[LoRARequest]) -> Optional[int]:
|
||||
if self.tokenizer is None:
|
||||
logger.warning("Using None for EOS token id because tokenizer "
|
||||
"is not initialized")
|
||||
return None
|
||||
|
||||
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
|
||||
|
||||
def _add_processed_request(
|
||||
self,
|
||||
request_id: str, # pylint: disable=unused-argument
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
request_id: str,
|
||||
processed_inputs: LLMInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
) -> None:
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seq_id = next(self.seq_counter)
|
||||
eos_token_id = self._get_eos_token_id(lora_request)
|
||||
|
||||
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
|
||||
lora_request)
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
seq_group = self._create_sequence_group_with_sampling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
elif isinstance(params, PoolingParams):
|
||||
seq_group = self._create_sequence_group_with_pooling(
|
||||
request_id,
|
||||
seq,
|
||||
params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either SamplingParams or PoolingParams must be provided.")
|
||||
|
||||
# Add the sequence group to the scheduler.
|
||||
self.scheduler.add_seq_group(seq_group)
|
||||
|
||||
def process_model_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
inputs: PromptInputs,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
):
|
||||
if prompt_token_ids is None:
|
||||
assert prompt is not None
|
||||
prompt_token_ids = self.tokenizer.encode(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
return prompt_token_ids
|
||||
) -> LLMInputs:
|
||||
if isinstance(inputs, str):
|
||||
inputs = {"prompt": inputs}
|
||||
|
||||
if "prompt_token_ids" not in inputs:
|
||||
tokenizer = self.get_tokenizer_group("prompts must be None if "
|
||||
"skip_tokenizer_init is True")
|
||||
|
||||
prompt_token_ids = tokenizer.encode(request_id=request_id,
|
||||
prompt=inputs["prompt"],
|
||||
lora_request=lora_request)
|
||||
else:
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
|
||||
return LLMInputs(prompt_token_ids=prompt_token_ids,
|
||||
prompt=inputs.get("prompt"),
|
||||
multi_modal_data=inputs.get("multi_modal_data"))
|
||||
|
||||
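
Editor's sketch (not part of the commit): how the consolidated input processing behaves for the two accepted prompt forms. It assumes an already-initialized LLMEngine bound to the name `engine`; the request IDs are illustrative only.

# A plain string is tokenized inside process_model_inputs.
text_inputs = engine.process_model_inputs(
    request_id="req-0",
    inputs="Hello, my name is",
)

# A dict that already carries token IDs skips tokenization.
token_inputs = engine.process_model_inputs(
    request_id="req-1",
    inputs={"prompt_token_ids": [1, 15043, 29892]},
)

# Both calls return an LLMInputs dict with "prompt_token_ids", "prompt"
# and "multi_modal_data" keys, which _add_processed_request consumes.
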
def add_request(
self,
request_id: str,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
"""Add a request to the engine's request pool.

@ -378,15 +504,14 @@ class LLMEngine:

Args:
request_id: The unique ID of the request.
prompt: The prompt string. Can be None if prompt_token_ids is
provided.
params: Parameters for sampling or pooling. SamplingParams
for text generation. PoolingParams for pooling.
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
:class:`~vllm.PoolingParams` for pooling.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
multi_modal_data: Multi modal data per request.

Details:
- Set arrival_time to the current time if it is None.
@ -417,59 +542,26 @@ class LLMEngine:
"not enabled!")
if arrival_time is None:
arrival_time = time.time()
prompt_token_ids = self.encode_request(

processed_inputs = self.process_model_inputs(request_id=request_id,
inputs=inputs,
lora_request=lora_request)

self._add_processed_request(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)

# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
eos_token_id = None
if self.tokenizer:
eos_token_id = self.tokenizer.get_lora_tokenizer(
lora_request).eos_token_id
else:
logger.warning("Use None for EOS token id because tokenizer is "
"not initialized")
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
eos_token_id, lora_request)

# Create a SequenceGroup based on SamplingParams or PoolingParams
if isinstance(params, SamplingParams):
seq_group = self._create_sequence_group_with_sampling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
elif isinstance(params, PoolingParams):
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
params,
arrival_time,
lora_request,
multi_modal_data,
)
else:
raise ValueError(
"Either SamplingParams or PoolingParams must be provided.")

# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
processed_inputs=processed_inputs,
params=params,
arrival_time=arrival_time,
lora_request=lora_request,
)

def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@ -495,8 +587,7 @@ class LLMEngine:
seqs=[seq],
arrival_time=arrival_time,
sampling_params=sampling_params,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
lora_request=lora_request)

return seq_group

@ -505,9 +596,8 @@ class LLMEngine:
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
arrival_time: float,
lora_request: Optional[LoRARequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
@ -517,7 +607,6 @@ class LLMEngine:
seqs=[seq],
arrival_time=arrival_time,
lora_request=lora_request,
multi_modal_data=multi_modal_data,
pooling_params=pooling_params)
return seq_group

@ -570,7 +659,7 @@ class LLMEngine:

def _process_model_outputs(
self,
output: List[Union[SamplerOutput, PoolerOutput]],
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
@ -585,7 +674,7 @@ class LLMEngine:
# Organize outputs by [sequence group][step] instead of
# [step][sequence group].
output_by_sequence_group = create_output_by_sequence_group(
sampler_outputs=output, num_seq_groups=len(scheduled_seq_groups))
output, num_seq_groups=len(scheduled_seq_groups))

# Update the scheduled sequence groups with the model outputs.
for scheduled_seq_group, outputs, seq_group_meta in zip(

@ -1,18 +1,20 @@
from typing import List
from typing import Sequence as GenericSequence
from typing import Union

from vllm.sequence import SamplerOutput, SequenceGroupOutput
from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput

def create_output_by_sequence_group(
sampler_outputs: List[SamplerOutput],
outputs: GenericSequence[Union[SamplerOutput, PoolerOutput]],
num_seq_groups: int) -> List[List[SequenceGroupOutput]]:
"""Helper method which transforms a 2d list organized by
[step][sequence group] into [sequence group][step].
"""
output_by_sequence_group: List[List[SamplerOutput]] = [
output_by_sequence_group: List[List[SequenceGroupOutput]] = [
[] for _ in range(num_seq_groups)
]
for step in sampler_outputs:
for step in outputs:
for i, sequence_group_output in enumerate(step):
output_by_sequence_group[i].append(sequence_group_output)

@ -1,11 +1,14 @@
from typing import List, Optional, Union
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload

import torch
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -13,7 +16,7 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter
from vllm.utils import Counter, deprecate_kwargs

logger = init_logger(__name__)

@ -28,8 +31,10 @@ class LLM:
mechanism and efficient memory management.

NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.

NOTE: For the comprehensive list of arguments, see
:class:`~vllm.EngineArgs`.

Args:
model: The name or path of a HuggingFace Transformers model.
@ -81,6 +86,18 @@ class LLM:
disable_custom_all_reduce: See ParallelConfig
"""

DEPRECATE_LEGACY: ClassVar[bool] = False
"""A flag to toggle whether to deprecate the legacy generate/encode API."""

@classmethod
@contextmanager
def deprecate_legacy_api(cls):
cls.DEPRECATE_LEGACY = True

yield

cls.DEPRECATE_LEGACY = False

def __init__(
self,
model: str,
@ -138,15 +155,101 @@ class LLM:
) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer

@overload  # LEGACY: single (prompt + optional token ids)
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
prompts: str,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...

@overload  # LEGACY: multi (prompt + optional token ids)
def generate(
self,
prompts: List[str],
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...

@overload  # LEGACY: single (token ids + optional prompt)
def generate(
self,
prompts: Optional[str] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...

@overload  # LEGACY: multi (token ids + optional prompt)
def generate(
self,
prompts: Optional[List[str]] = None,
sampling_params: Optional[Union[SamplingParams,
List[SamplingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...

@overload  # LEGACY: single or multi token ids [pos-only]
def generate(
self,
prompts: None,
sampling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
...

@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/,  # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
...

@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.

@ -155,49 +258,138 @@ class LLM:
into a single list and pass it to this method.

Args:
prompts: A list of prompts to generate completions for.
inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
When it is a list, the list must have the same length as the
prompts and it is paired one by one with the prompt.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.

Returns:
A list of `RequestOutput` objects containing the
generated completions in the same order as the input prompts.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)

if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()

requests_data = self._validate_and_prepare_requests(
prompts,
sampling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
self._validate_and_add_requests(
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
)

# Add requests to the engine and run the engine
for request_data in requests_data:
self._add_request(**request_data)

return self._run_engine(use_tqdm)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, RequestOutput)
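
Editor's sketch (not part of the commit): a minimal comparison of the consolidated generate() call against the deprecated keyword form, assuming the public API shown in this diff. The model name is arbitrary and used only for illustration.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model; illustrative choice
params = SamplingParams(temperature=0.8)

# New API: text prompts and pre-tokenized prompts travel in one
# positional `inputs` argument; sampling_params is keyword-only.
outputs = llm.generate(
    ["Hello, my name is", {"prompt_token_ids": [1, 15043]}],
    sampling_params=params,
)

# Legacy API: still accepted, but routed through _convert_v1_inputs and
# flagged by deprecate_kwargs once DEPRECATE_LEGACY is enabled.
legacy_outputs = llm.generate(prompts="Hello, my name is",
                              sampling_params=params)
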
@overload  # LEGACY: single (prompt + optional token ids)
def encode(
self,
prompts: Optional[Union[str, List[str]]] = None,
prompts: str,
pooling_params: Optional[Union[PoolingParams,
List[PoolingParams]]] = None,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[int]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...

@overload  # LEGACY: multi (prompt + optional token ids)
def encode(
self,
prompts: List[str],
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...

@overload  # LEGACY: single (token ids + optional prompt)
def encode(
self,
prompts: Optional[str] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[int],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...

@overload  # LEGACY: multi (token ids + optional prompt)
def encode(
self,
prompts: Optional[List[str]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
*,
prompt_token_ids: List[List[int]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...

@overload  # LEGACY: single or multi token ids [pos-only]
def encode(
self,
prompts: None,
pooling_params: None,
prompt_token_ids: Union[List[int], List[List[int]]],
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
...

@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
/,  # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[EmbeddingRequestOutput]:
...

@deprecate_kwargs("prompts",
"prompt_token_ids",
"multi_modal_data",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter "
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[EmbeddingRequestOutput]:
"""Generates the completions for the input prompts.

@ -206,124 +398,133 @@ class LLM:
into a single list and pass it to this method.

Args:
prompts: A list of prompts to generate completions for.
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
multi_modal_data: Multi modal data.

Returns:
A list of `EmbeddingRequestOutput` objects containing the
generated embeddings in the same order as the input prompts.
"""
if prompt_token_ids is not None or multi_modal_data is not None:
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
multi_modal_data=multi_modal_data,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)

if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()

requests_data = self._validate_and_prepare_requests(
prompts,
pooling_params,
prompt_token_ids,
lora_request,
multi_modal_data,
self._validate_and_add_requests(
inputs=inputs,
params=pooling_params,
lora_request=lora_request,
)

# Add requests to the engine and run the engine
for request_data in requests_data:
self._add_request(**request_data)
outputs = self._run_engine(use_tqdm=use_tqdm)
return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

return self._run_engine(use_tqdm)

def _validate_and_prepare_requests(
# LEGACY
def _convert_v1_inputs(
self,
prompts: Optional[Union[str, List[str]]],
params: Union[Union[SamplingParams, PoolingParams],
List[Union[SamplingParams,
PoolingParams]]],  # Unified parameter
prompt_token_ids: Optional[List[List[int]]] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> List[dict]:
"""Validates and prepares request data for adding to the engine.

Ensures prompts and token IDs are consistent, and returns a list of
dictionaries with request data for further processing.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if self.llm_engine.model_config.skip_tokenizer_init \
and prompts is not None:
raise ValueError("prompts must be None if skip_tokenizer_init "
"is True")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
multi_modal_data: Optional[MultiModalData],
):
# skip_tokenizer_init is now checked in engine

if prompts is not None:
prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
if prompt_token_ids is not None:
prompt_token_ids = [
p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
]

num_requests = None
if prompts is not None:
num_requests = len(prompts)
else:
assert prompt_token_ids is not None
if prompt_token_ids is not None:
if (num_requests is not None
and num_requests != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")

num_requests = len(prompt_token_ids)
if num_requests is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")

inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError

if multi_modal_data is not None:
item["multi_modal_data"] = multi_modal_data

inputs.append(item)

return inputs

def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[LoRARequest],
) -> None:
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]

num_requests = len(inputs)

if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
if multi_modal_data:
multi_modal_data.data = multi_modal_data.data.to(torch.float16)

# Add requests to the engine.
requests_data = []
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]

multi_modal_item = MultiModalData(
type=multi_modal_data.type,
data=multi_modal_data.data[i].unsqueeze(0),
) if multi_modal_data else None

requests_data.append({
"prompt":
prompt,
"params":
params[i] if isinstance(params, list) else params,
"prompt_token_ids":
token_ids,
"lora_request":
lora_request,
"multi_modal_data":
multi_modal_item,
})

return requests_data
for i, request_inputs in enumerate(inputs):
self._add_request(
request_inputs,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request,
)

def _add_request(
self,
prompt: Optional[str],
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
inputs,
params,
prompt_token_ids,
lora_request=lora_request,
multi_modal_data=multi_modal_data)
lora_request=lora_request)

def _run_engine(
self, use_tqdm: bool
self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Initialize tqdm.
if use_tqdm:
@ -355,5 +556,4 @@ class LLM:
# Sort the outputs by request ID.
# This is necessary because some requests may be finished earlier than
# its previous requests.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
return sorted(outputs, key=lambda x: int(x.request_id))

@ -176,9 +176,15 @@ class OpenAIServingChat(OpenAIServing):
except ValueError as e:
return self.create_error_response(str(e))

result_generator = self.engine.generate(prompt_text, sampling_params,
request_id, prompt_ids,
lora_request)
result_generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
request_id,
lora_request,
)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
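
Editor's sketch (not part of the commit): how a server-side caller now passes both the prompt text and its token IDs to the async engine in a single dict. The surrounding objects (`engine`, `sampling_params`, `request_id`, `prompt_ids`) are assumed to exist in the serving code, as in the hunks above.

async def stream_answer(engine, sampling_params, request_id, prompt_ids):
    # `engine` is an AsyncLLMEngine; the dict is a TextTokensPrompt.
    results_generator = engine.generate(
        {"prompt": "What is the capital of France?",
         "prompt_token_ids": prompt_ids},
        sampling_params,
        request_id,
    )
    async for request_output in results_generator:
        yield request_output  # stream partial RequestOutputs to the client
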
@ -119,12 +119,17 @@ class OpenAIServingCompletion(OpenAIServing):
truncate_prompt_tokens)
prompt_ids, prompt_text = prompt_formats

generators.append(
self.engine.generate(prompt_text,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids,
lora_request=lora_request))
generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
sampling_params,
f"{request_id}-{i}",
lora_request=lora_request,
)

generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

@ -1,5 +1,5 @@
import time
from typing import AsyncIterator, List, Tuple
from typing import AsyncIterator, List, Optional, Tuple

from fastapi import Request

@ -100,11 +100,16 @@ class OpenAIServingEmbedding(OpenAIServing):

prompt_ids, prompt_text = prompt_formats

generators.append(
self.engine.generate(prompt_text,
pooling_params,
f"{request_id}-{i}",
prompt_token_ids=prompt_ids))
generator = self.engine.encode(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
pooling_params,
f"{request_id}-{i}",
)

generators.append(generator)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -113,16 +118,21 @@ class OpenAIServingEmbedding(OpenAIServing):
int, EmbeddingRequestOutput]] = merge_async_iterators(*generators)

# Non-streaming response
final_res_batch: EmbeddingRequestOutput = [None] * len(prompts)
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
final_res_batch: List[Optional[EmbeddingRequestOutput]]
final_res_batch = [None] * len(prompts)
try:
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))

return response

@ -143,7 +143,8 @@ class OpenAIServing:
return json_str

async def _check_model(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
@ -155,7 +156,8 @@ class OpenAIServing:
status_code=HTTPStatus.NOT_FOUND)

def _maybe_get_lora(
self, request: Union[CompletionRequest, ChatCompletionRequest]
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest]
) -> Optional[LoRARequest]:
if request.model in self.served_model_names:
return None

130
vllm/inputs.py
Normal file
@ -0,0 +1,130 @@
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
TypedDict, Union, cast, overload)

from typing_extensions import NotRequired

if TYPE_CHECKING:
from vllm.sequence import MultiModalData

class ParsedText(TypedDict):
content: str
is_tokens: Literal[False]

class ParsedTokens(TypedDict):
content: List[int]
is_tokens: Literal[True]

# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
...

@overload
def parse_and_batch_prompt(
prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
...

def parse_and_batch_prompt(
prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
if isinstance(prompt, str):
# case 1: a string
return [ParsedText(content=prompt, is_tokens=False)]

if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")

if isinstance(prompt[0], str):
# case 2: array of strings
return [
ParsedText(content=elem, is_tokens=False)
for elem in cast(List[str], prompt)
]
if isinstance(prompt[0], int):
# case 3: array of tokens
elem = cast(List[int], prompt)
return [ParsedTokens(content=elem, is_tokens=True)]
if isinstance(prompt[0], list):
if len(prompt[0]) == 0:
raise ValueError("please provide at least one prompt")

if isinstance(prompt[0][0], int):
# case 4: array of token arrays
return [
ParsedTokens(content=elem, is_tokens=True)
for elem in cast(List[List[int]], prompt)
]

raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
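
Editor's sketch (not part of the commit): the four input shapes the function above accepts, with arbitrary example values.

parse_and_batch_prompt("Hello")                # case 1 -> one ParsedText
parse_and_batch_prompt(["Hello", "World"])     # case 2 -> two ParsedText
parse_and_batch_prompt([1, 2, 3])              # case 3 -> one ParsedTokens
parse_and_batch_prompt([[1, 2, 3], [4, 5]])    # case 4 -> two ParsedTokens
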
class TextPrompt(TypedDict):
"""Schema for a text prompt."""

prompt: str
"""The input text to be tokenized before passing to the model."""

multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""

class TokensPrompt(TypedDict):
"""Schema for a tokenized prompt."""

prompt_token_ids: List[int]
"""A list of token IDs to pass to the model."""

multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""

class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""

prompt: str
"""The prompt text."""

prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""

multi_modal_data: NotRequired["MultiModalData"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""

PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""

PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""

class LLMInputs(TypedDict):
prompt_token_ids: List[int]
prompt: Optional[str]
multi_modal_data: Optional["MultiModalData"]
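
Editor's sketch (not part of the commit): how the prompt schemas defined above can be mixed in one batched call to LLM.generate. The model name is arbitrary and used only for illustration.

from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # illustrative model choice

prompts = [
    "Hello, my name is",                               # bare string
    TextPrompt(prompt="The capital of France is"),     # text schema
    TokensPrompt(prompt_token_ids=[1, 15043, 29892]),  # token schema
]
outputs = llm.generate(prompts,
                       sampling_params=SamplingParams(max_tokens=16))
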
@ -1,4 +1,5 @@
import time
from dataclasses import dataclass
from typing import List, Optional, Union

from vllm.lora.request import LoRARequest
@ -6,6 +7,7 @@ from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceStatus)

@dataclass
class CompletionOutput:
"""The output data of one completion output of a request.

@ -24,25 +26,14 @@ class CompletionOutput:
lora_request: The LoRA request that was used to generate the output.
"""

def __init__(
self,
index: int,
text: str,
token_ids: List[int],
cumulative_logprob: float,
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
stop_reason: Union[int, str, None] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.index = index
self.text = text
self.token_ids = token_ids
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.stop_reason = stop_reason
self.lora_request = lora_request
index: int
text: str
token_ids: List[int]
cumulative_logprob: float
logprobs: Optional[SampleLogprobs]
finish_reason: Optional[str] = None
stop_reason: Union[int, str, None] = None
lora_request: Optional[LoRARequest] = None

def finished(self) -> bool:
return self.finish_reason is not None
@ -57,6 +48,7 @@ class CompletionOutput:
f"stop_reason={self.stop_reason})")

@dataclass
class EmbeddingOutput:
"""The output data of one completion output of a request.

@ -65,15 +57,11 @@ class EmbeddingOutput:
length of vector depends on the model as listed in the embedding guide.
"""

def __init__(
self,
embedding: List[float],
) -> None:
self.embedding = embedding
embedding: List[float]

def __repr__(self) -> str:
return (f"EmbeddingOutput("
f"embedding={len(self.embedding)}")
f"embedding={len(self.embedding)})")

class RequestOutput:
@ -93,7 +81,7 @@ class RequestOutput:
def __init__(
self,
request_id: str,
prompt: str,
prompt: Optional[str],
prompt_token_ids: List[int],
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
@ -183,7 +171,7 @@ class EmbeddingRequestOutput:
finished (bool): A flag indicating whether the embedding is completed.
"""

def __init__(self, request_id: str, outputs: 'EmbeddingOutput',
def __init__(self, request_id: str, outputs: "EmbeddingOutput",
prompt_token_ids: List[int], finished: bool):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids

@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

from vllm.block import LogicalTokenBlock
from vllm.inputs import LLMInputs
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@ -210,8 +211,7 @@ class Sequence:

Args:
seq_id: The ID of the sequence.
prompt: The prompt of the sequence.
prompt_token_ids: The token IDs of the prompt.
inputs: The inputs of the sequence.
block_size: The block size of the sequence. Should be the same as the
block size used by the block manager and cache engine.
lora_request: LoRA request.
@ -220,25 +220,24 @@ class Sequence:
def __init__(
self,
seq_id: int,
prompt: str,
prompt_token_ids: List[int],
inputs: LLMInputs,
block_size: int,
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.seq_id = seq_id
self.prompt = prompt
self.inputs = inputs
self.block_size = block_size
self.eos_token_id = eos_token_id
self.lora_request = lora_request

self.data: SequenceData = SequenceData(prompt_token_ids)
self.data = SequenceData(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""

self.logical_token_blocks: List[LogicalTokenBlock] = []
# Initialize the logical token blocks with the prompt token ids.
self._append_tokens_to_blocks(prompt_token_ids)
self._append_tokens_to_blocks(self.prompt_token_ids)
self.status = SequenceStatus.WAITING
self.stop_reason: Union[int, str, None] = None

@ -248,6 +247,18 @@ class Sequence:
# Input + output tokens
self.tokens: Optional[List[str]] = None

@property
def prompt(self) -> Optional[str]:
return self.inputs["prompt"]

@property
def prompt_token_ids(self) -> List[int]:
return self.inputs["prompt_token_ids"]

@property
def multi_modal_data(self) -> Optional["MultiModalData"]:
return self.inputs["multi_modal_data"]

@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@ -415,7 +426,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
multi_modal_data: Multi modal data associated with the request.
embeddings: The embeddings vectors of the prompt of the sequence group
for an embedding model.
pooling_params: The pooling parameters used to generate the pooling
@ -429,7 +439,6 @@ class SequenceGroup:
arrival_time: float,
sampling_params: Optional[SamplingParams] = None,
lora_request: Optional[LoRARequest] = None,
multi_modal_data: Optional[MultiModalData] = None,
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
) -> None:
@ -444,12 +453,11 @@ class SequenceGroup:
self.lora_request = lora_request
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
self.multi_modal_data = multi_modal_data
self.embeddings = embeddings
self.pooling_params = pooling_params

@property
def prompt(self) -> str:
def prompt(self) -> Optional[str]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).prompt
@ -458,7 +466,13 @@ class SequenceGroup:
def prompt_token_ids(self) -> List[int]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
return next(iter(self.seqs_dict.values())).prompt_token_ids

@property
def multi_modal_data(self) -> Optional[MultiModalData]:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).multi_modal_data

@property
def lora_int_id(self) -> int:

@ -11,7 +11,7 @@ import threading
import uuid
import warnings
from collections import defaultdict
from functools import lru_cache, partial
from functools import lru_cache, partial, wraps
from platform import uname
from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic,
Hashable, List, Optional, OrderedDict, Tuple, TypeVar,
@ -658,3 +658,44 @@ def enable_trace_function_call_for_thread() -> None:
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)

def identity(value: T) -> T:
return value

F = TypeVar('F', bound=Callable[..., Any])

def deprecate_kwargs(
*kws: str,
is_deprecated: Union[bool, Callable[[], bool]] = True,
additional_message: Optional[str] = None) -> Callable[[F], F]:
deprecated_kws = set(kws)

if not callable(is_deprecated):
is_deprecated = partial(identity, is_deprecated)

def wrapper(fn: F) -> F:

@wraps(fn)
def inner(*args, **kwargs):
if is_deprecated():
deprecated_kwargs = kwargs.keys() & deprecated_kws
if deprecated_kwargs:
msg = (
f"The keyword arguments {deprecated_kwargs} are "
"deprecated and will be removed in a future update.")
if additional_message is not None:
msg += f" {additional_message}"

warnings.warn(
DeprecationWarning(msg),
stacklevel=3,  # The inner function takes up one level
)

return fn(*args, **kwargs)

return inner  # type: ignore

return wrapper
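
Editor's sketch (not part of the commit): how the deprecate_kwargs decorator defined above is meant to be applied. The function, argument names, and flag below are illustrative only; in this PR the real gate is LLM.DEPRECATE_LEGACY.

SOME_FLAG = True  # hypothetical toggle standing in for LLM.DEPRECATE_LEGACY

@deprecate_kwargs("old_arg",
                  is_deprecated=lambda: SOME_FLAG,
                  additional_message="Please use 'new_arg' instead.")
def my_api(new_arg=None, old_arg=None):
    ...

# Calling my_api(old_arg=1) emits a DeprecationWarning only while
# SOME_FLAG evaluates to True; other keyword arguments are unaffected.
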