Rename PromptInputs and inputs with backward compatibility (#8760)

Cyrus Leung 2024-09-26 00:36:47 +08:00 committed by GitHub
parent 0c4d2ad5e6
commit 28e1299e60
21 changed files with 438 additions and 245 deletions

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptInputs from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000, dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size, size=(args.batch_size,
args.input_len)) args.input_len))
dummy_inputs: List[PromptInputs] = [{ dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch "prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()] } for batch in dummy_prompt_token_ids.tolist()]
@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
], ],
on_trace_ready=torch.profiler.tensorboard_trace_handler( on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p: str(profile_dir))) as p:
llm.generate(dummy_inputs, llm.generate(dummy_prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
print(p.key_averages()) print(p.key_averages())
else: else:
start_time = time.perf_counter() start_time = time.perf_counter()
llm.generate(dummy_inputs, llm.generate(dummy_prompts,
sampling_params=sampling_params, sampling_params=sampling_params,
use_tqdm=False) use_tqdm=False)
end_time = time.perf_counter() end_time = time.perf_counter()
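
As a usage note, the renamed alias works identically outside the benchmark script. A minimal sketch, assuming a placeholder model name and toy token ids (neither comes from this commit):

from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptType

llm = LLM(model="facebook/opt-125m")  # placeholder model for the sketch
sampling_params = SamplingParams(max_tokens=16)

# Token-id prompts keep the same dict schema; only the type alias was renamed.
dummy_prompts: List[PromptType] = [
    {"prompt_token_ids": [1, 2, 3, 4]},
    {"prompt_token_ids": [5, 6, 7, 8]},
]
outputs = llm.generate(dummy_prompts, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)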

View File

@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>` Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`. by following :ref:`this guide <adding_multimodal_plugin>`.
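
To illustrate the ``multi_modal_data`` field referenced above, here is a hedged sketch; the model name, prompt template, and image path are assumptions for illustration only:

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # assumed vision-language model
image = Image.open("example.jpg")            # assumed local image

# A text prompt plus image data, following the prompt schema described above.
outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    sampling_params=SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)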

View File

@ -1,7 +1,7 @@
LLM Inputs LLM Inputs
========== ==========
.. autodata:: vllm.inputs.PromptInputs .. autodata:: vllm.inputs.PromptType
.. autoclass:: vllm.inputs.TextPrompt .. autoclass:: vllm.inputs.TextPrompt
:show-inheritance: :show-inheritance:

View File

@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

View File

@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_new_requests_event(): async def test_new_requests_event():
params = SamplingParams()
engine = MockAsyncLLMEngine() engine = MockAsyncLLMEngine()
engine.start_background_loop() engine.start_background_loop()
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0 assert engine.engine.step_calls == 0
await engine.add_request("1", "", None) await engine.add_request("1", "", params)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1 assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1 assert engine.engine.step_calls == 1
await engine.add_request("2", "", None) await engine.add_request("2", "", params)
engine.engine.generate("2") engine.engine.generate("2")
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0) await asyncio.sleep(0)
@ -111,7 +113,7 @@ async def test_new_requests_event():
await asyncio.sleep(0.001) await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls assert engine.engine.step_calls == old_step_calls
await engine.add_request("3", "", None) await engine.add_request("3", "", params)
await asyncio.sleep(0.01) await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3 assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1 assert engine.engine.step_calls == old_step_calls + 1

View File

@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert [o.outputs for o in o1] == [o.outputs for o in o2] assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output) assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams() pooling_params = PoolingParams()

View File

@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2] assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)
v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output) assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)
v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup @pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0) sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

View File

@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass. # Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR): with pytest.raises(RAISED_ERROR):
async for _ in client.generate(inputs="Hello my name is", async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
request_id=uuid.uuid4()): request_id=uuid.uuid4()):
pass pass
@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR. # Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError): with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is", async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
request_id=uuid.uuid4()): request_id=uuid.uuid4()):
pass pass
@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR # Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError): with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is", async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
request_id=uuid.uuid4()): request_id=uuid.uuid4()):
pass pass
@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo") # with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo: with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate( async for _ in client.generate(
inputs="Hello my name is", prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10), sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()): request_id=uuid.uuid4()):
pass pass
@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server. # Invalid request should fail, but not crash the server.
with pytest.raises(ValueError): with pytest.raises(ValueError):
async for _ in client.generate(inputs="Hello my name is", async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
request_id="abcd-1", request_id="abcd-1",
lora_request=LoRARequest( lora_request=LoRARequest(
@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
pass pass
# This request should be okay. # This request should be okay.
async for _ in client.generate(inputs="Hello my name is", async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(), sampling_params=SamplingParams(),
request_id="abcd-2"): request_id="abcd-2"):
pass pass
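
The tests above exercise the renamed keyword on the multiprocessing client. Roughly, both spellings behave as in the sketch below during the deprecation window; the client construction is elided and the surrounding coroutine exists only for illustration:

import uuid

from vllm import SamplingParams


async def demo(client):
    # New spelling of the keyword argument.
    async for _ in client.generate(prompt="Hello my name is",
                                   sampling_params=SamplingParams(),
                                   request_id=str(uuid.uuid4())):
        pass

    # Old spelling still works, but @deprecate_kwargs emits a
    # DeprecationWarning pointing the caller at 'prompt'.
    async for _ in client.generate(inputs="Hello my name is",
                                   sampling_params=SamplingParams(),
                                   request_id=str(uuid.uuid4())):
        pass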

View File

@ -20,7 +20,7 @@ async def generate(
count = 0 count = 0
async for out in client.generate( async for out in client.generate(
request_id=request_id, request_id=request_id,
inputs="Hello my name is Robert and", prompt="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens, sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)): temperature=0)):

View File

@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput, from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput) EmbeddingRequestOutput, RequestOutput)
@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__", "__version_tuple__",
"LLM", "LLM",
"ModelRegistry", "ModelRegistry",
"PromptInputs", "PromptType",
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"SamplingParams", "SamplingParams",

View File

@ -2,8 +2,8 @@ import asyncio
import time import time
import weakref import weakref
from functools import partial from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List, from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
Mapping, Optional, Set, Tuple, Type, Union) List, Mapping, Optional, Set, Tuple, Type, Union, overload)
from weakref import ReferenceType from weakref import ReferenceType
import vllm.envs as envs import vllm.envs as envs
@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind from vllm.utils import deprecate_kwargs, weak_bind
logger = init_logger(__name__) logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@ -402,17 +402,54 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop.""" """Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async() await self.model_executor.stop_remote_worker_execution_loop_async()
@overload # DEPRECATED
async def add_request_async( async def add_request_async(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, *,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@overload
async def add_request_async(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request_async(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> None:
"""Async version of :meth:`add_request`.""" """Async version of :meth:`add_request`."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
@ -420,7 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time() arrival_time = time.time()
preprocessed_inputs = await self.input_preprocessor.preprocess_async( preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
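
The backward compatibility above hinges on vllm.utils.deprecate_kwargs, whose implementation is not part of this diff. A minimal sketch of a decorator with the same observable behaviour, written here only for orientation (names and warning wording are assumptions):

import functools
import warnings
from typing import Callable, Optional


def deprecate_kwargs(*kws: str,
                     is_deprecated: Callable[[], bool] = lambda: True,
                     additional_message: Optional[str] = None):
    """Warn whenever one of the named keyword arguments is passed explicitly."""
    deprecated = set(kws)

    def wrapper(fn):

        @functools.wraps(fn)
        def inner(*args, **kwargs):
            if is_deprecated():
                used = deprecated & kwargs.keys()
                if used:
                    msg = (f"The keyword arguments {sorted(used)} are "
                           "deprecated and will be removed in a future update.")
                    if additional_message is not None:
                        msg += f" {additional_message}"
                    warnings.warn(DeprecationWarning(msg), stacklevel=3)
            return fn(*args, **kwargs)

        return inner

    return wrapper

The @overload stubs exist purely for type checkers: the keyword-only 'inputs' variant is marked deprecated, while the runtime body reconciles the two names with 'if inputs is not None: prompt = inputs'.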
@ -774,16 +811,55 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way # This method does not need to be async, but kept that way
# for backwards compatibility. # for backwards compatibility.
async def add_request( @overload # DEPRECATED
def add_request(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, *,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@overload
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if not self.is_running: if not self.is_running:
if self.start_engine_loop: if self.start_engine_loop:
self.start_background_loop() self.start_background_loop()
@ -797,7 +873,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request( stream = self._request_tracker.add_request(
request_id, request_id,
verbose=self.log_requests, verbose=self.log_requests,
inputs=inputs, prompt=prompt,
params=params, params=params,
arrival_time=arrival_time or time.time(), arrival_time=arrival_time or time.time(),
lora_request=lora_request, lora_request=lora_request,
@ -808,7 +884,7 @@ class AsyncLLMEngine:
async def generate( async def generate(
self, self,
inputs: PromptInputs, prompt: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
@ -822,8 +898,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
@ -881,7 +956,7 @@ class AsyncLLMEngine:
""" """
async for output in await self.add_request( async for output in await self.add_request(
request_id, request_id,
inputs, prompt,
sampling_params, sampling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
@ -891,7 +966,7 @@ class AsyncLLMEngine:
async def encode( async def encode(
self, self,
inputs: PromptInputs, prompt: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
@ -904,8 +979,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
pooling_params: The pooling parameters of the request. pooling_params: The pooling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
@ -959,7 +1033,7 @@ class AsyncLLMEngine:
""" """
async for output in await self.add_request( async for output in await self.add_request(
request_id, request_id,
inputs, prompt,
pooling_params, pooling_params,
lora_request=lora_request, lora_request=lora_request,
trace_headers=trace_headers, trace_headers=trace_headers,
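
As a usage sketch for the renamed public coroutine parameters, positional call sites are unaffected; the engine arguments below are placeholders, not taken from this commit:

import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model

    # 'prompt' replaces the old 'inputs' name; positional use is unchanged.
    final = None
    async for output in engine.generate("Hello my name is",
                                        SamplingParams(max_tokens=16),
                                        request_id="req-0"):
        final = output
    if final is not None:
        print(final.outputs[0].text)


asyncio.run(main())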

View File

@ -6,7 +6,7 @@ from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional) Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence from typing import Sequence as GenericSequence
from typing import Set, Type, Union from typing import Set, Type, Union, overload
import torch import torch
from typing_extensions import TypeVar from typing_extensions import TypeVar
@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs) InputRegistry, LLMInputs, PromptType)
from vllm.inputs.preprocess import InputPreprocessor from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs) BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter, Device, weak_bind from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
@ -689,16 +689,51 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None: def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop() self.model_executor.stop_remote_worker_execution_loop()
@overload # DEPRECATED
def add_request( def add_request(
self, self,
request_id: str, request_id: str,
inputs: PromptInputs, *,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None, arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0, priority: int = 0,
) -> None:
...
@overload
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None: ) -> None:
"""Add a request to the engine's request pool. """Add a request to the engine's request pool.
@ -708,8 +743,7 @@ class LLMEngine:
Args: Args:
request_id: The unique ID of the request. request_id: The unique ID of the request.
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
params: Parameters for sampling or pooling. params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation. :class:`~vllm.SamplingParams` for text generation.
@ -744,6 +778,10 @@ class LLMEngine:
>>> # continue the request processing >>> # continue the request processing
>>> ... >>> ...
""" """
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config: if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is " raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!") "not enabled!")
@ -756,7 +794,7 @@ class LLMEngine:
arrival_time = time.time() arrival_time = time.time()
preprocessed_inputs = self.input_preprocessor.preprocess( preprocessed_inputs = self.input_preprocessor.preprocess(
inputs, prompt,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,

View File

@ -1,13 +1,14 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import List, Mapping, Optional, Union from typing import List, Mapping, Optional, Union, overload
from vllm import PoolingParams from vllm import PoolingParams
from vllm.inputs import PromptInputs from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs
VLLM_RPC_SUCCESS_STR = "SUCCESS" VLLM_RPC_SUCCESS_STR = "SUCCESS"
@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError):
@dataclass @dataclass
class RPCProcessRequest: class RPCProcessRequest:
inputs: PromptInputs prompt: PromptType
params: Union[SamplingParams, PoolingParams] params: Union[SamplingParams, PoolingParams]
request_id: str request_id: str
lora_request: Optional[LoRARequest] = None lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
@overload # DEPRECATED
def __init__(
self,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@overload
def __init__(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def __init__(
self,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
if inputs is not None:
prompt = inputs
assert (prompt is not None and params is not None
and request_id is not None)
super().__init__()
self.prompt = prompt
self.params = params
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
@dataclass @dataclass
class RPCError: class RPCError:
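
Since @dataclass keeps a user-defined __init__ rather than generating one, RPCProcessRequest can accept the deprecated keyword while its fields stay declarative. A condensed sketch of the same pattern on a toy class (the field names here are illustrative, not from vLLM):

import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class Request:
    prompt: str
    request_id: str

    def __init__(self,
                 prompt: Optional[str] = None,
                 request_id: Optional[str] = None,
                 *,
                 inputs: Optional[str] = None) -> None:  # DEPRECATED alias
        if inputs is not None:
            warnings.warn("'inputs' is deprecated, use 'prompt' instead.",
                          DeprecationWarning,
                          stacklevel=2)
            prompt = inputs
        assert prompt is not None and request_id is not None
        self.prompt = prompt
        self.request_id = request_id


# Both spellings produce equal objects during the deprecation window.
assert Request("hi", request_id="1") == Request(inputs="hi", request_id="1")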

View File

@ -3,7 +3,7 @@ import copy
import pickle import pickle
from contextlib import contextmanager, suppress from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional, from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union) Union, overload)
import cloudpickle import cloudpickle
import zmq import zmq
@ -24,13 +24,14 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupRequest, RPCStartupResponse) RPCStartupRequest, RPCStartupResponse)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs from vllm.inputs import PromptType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
logger = init_logger(__name__) logger = init_logger(__name__)
@ -366,14 +367,45 @@ class MQLLMEngineClient:
def dead_error(self) -> BaseException: def dead_error(self) -> BaseException:
return ENGINE_DEAD_ERROR(self._errored_with) return ENGINE_DEAD_ERROR(self._errored_with)
@overload # DEPRECATED
def generate( def generate(
self, self,
inputs: PromptInputs, *,
inputs: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
...
@overload
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def generate(
self,
prompt: Optional[PromptType] = None,
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request. """Generate outputs for a request.
@ -382,8 +414,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
sampling_params: The sampling parameters of the request. sampling_params: The sampling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
@ -392,17 +423,51 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use prompt_adapter_request: Prompt Adapter request to use
for generation, if any. for generation, if any.
""" """
return self._process_request(inputs, sampling_params, request_id, if inputs is not None:
prompt = inputs
assert (prompt is not None and sampling_params is not None
and request_id is not None)
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers, lora_request, trace_headers,
prompt_adapter_request) prompt_adapter_request)
@overload # DEPRECATED
def encode( def encode(
self, self,
inputs: PromptInputs, *,
inputs: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@overload
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def encode(
self,
prompt: Optional[PromptType] = None,
pooling_params: Optional[PoolingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]: ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model. """Generate outputs for a request from an embedding model.
@ -411,8 +476,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller. from the LLMEngine to the caller.
Args: Args:
inputs: The inputs to the LLM. See prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input. for more details about the format of each input.
pooling_params: The pooling parameters of the request. pooling_params: The pooling parameters of the request.
request_id: The unique id of the request. request_id: The unique id of the request.
@ -423,12 +487,17 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request. for the request.
""" """
return self._process_request(inputs, pooling_params, request_id, if inputs is not None:
prompt = inputs
assert (prompt is not None and pooling_params is not None
and request_id is not None)
return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers) lora_request, trace_headers)
async def _process_request( async def _process_request(
self, self,
inputs: PromptInputs, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
@ -461,7 +530,7 @@ class MQLLMEngineClient:
request_bytes = pickle.dumps( request_bytes = pickle.dumps(
RPCProcessRequest( RPCProcessRequest(
inputs=inputs, prompt=prompt,
params=params, params=params,
request_id=request_id, request_id=request_id,
lora_request=lora_request, lora_request=lora_request,

View File

@ -271,7 +271,7 @@ class MQLLMEngine:
try: try:
self.engine.add_request( self.engine.add_request(
request_id=request_id, request_id=request_id,
inputs=request.inputs, prompt=request.prompt,
params=request.params, params=request.params,
lora_request=request.lora_request, lora_request=request.lora_request,
trace_headers=request.trace_headers, trace_headers=request.trace_headers,

View File

@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from vllm.config import DecodingConfig, ModelConfig from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs from vllm.inputs.data import PromptType
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -35,19 +35,19 @@ class EngineClient(Protocol):
def generate( def generate(
self, self,
inputs: PromptInputs, prompt: PromptType,
sampling_params: SamplingParams, sampling_params: SamplingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None, trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]: ) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request""" """Generate outputs for a request."""
... ...
def encode( def encode(
self, self,
inputs: PromptInputs, prompt: PromptType,
pooling_params: PoolingParams, pooling_params: PoolingParams,
request_id: str, request_id: str,
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,

View File

@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template, apply_hf_chat_template,
apply_mistral_chat_template, apply_mistral_chat_template,
parse_chat_messages) parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
@ -293,8 +293,8 @@ class LLM:
@overload @overload
def generate( def generate(
self, self,
inputs: Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[PromptType, Sequence[PromptType]],
/, # We may enable `inputs` keyword after removing the old API /,
*, *,
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
@ -304,14 +304,13 @@ class LLM:
... ...
@deprecate_kwargs( @deprecate_kwargs(
"prompts",
"prompt_token_ids", "prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.", additional_message="Please use the 'prompts' parameter instead.",
) )
def generate( def generate(
self, self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None, Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams, sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None, Sequence[SamplingParams]]] = None,
@ -330,7 +329,9 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
inputs: A list of inputs to generate completions for. prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
sampling_params: The sampling parameters for text generation. If sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters. None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt. When it is a single value, it is applied to every prompt.
@ -358,12 +359,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).") "models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None: if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts), prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) )
else: else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if isinstance(guided_options_request, dict): if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1: if len(guided_options_request) > 1:
@ -378,7 +380,7 @@ class LLM:
sampling_params = SamplingParams() sampling_params = SamplingParams()
self._validate_and_add_requests( self._validate_and_add_requests(
inputs=inputs, prompts=parsed_prompts,
params=sampling_params, params=sampling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
@ -648,8 +650,8 @@ class LLM:
@overload @overload
def encode( def encode(
self, self,
inputs: Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[PromptType, Sequence[PromptType]],
/, # We may enable `inputs` keyword after removing the old API /,
*, *,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
@ -659,14 +661,13 @@ class LLM:
... ...
@deprecate_kwargs( @deprecate_kwargs(
"prompts",
"prompt_token_ids", "prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY, is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.", additional_message="Please use the 'prompts' parameter instead.",
) )
def encode( def encode(
self, self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None, Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams, pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None, Sequence[PoolingParams]]] = None,
@ -682,9 +683,9 @@ class LLM:
into a single list and pass it to this method. into a single list and pass it to this method.
Args: Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for prompts: The prompts to the LLM. You may pass a sequence of prompts
batch inference. See :class:`~vllm.inputs.PromptInputs` for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input. for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters. use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar. use_tqdm: Whether to use tqdm to display the progress bar.
@ -707,19 +708,20 @@ class LLM:
) )
if prompt_token_ids is not None: if prompt_token_ids is not None:
inputs = self._convert_v1_inputs( parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts), prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
) )
else: else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if pooling_params is None: if pooling_params is None:
# Use default pooling params. # Use default pooling params.
pooling_params = PoolingParams() pooling_params = PoolingParams()
self._validate_and_add_requests( self._validate_and_add_requests(
inputs=inputs, prompts=parsed_prompts,
params=pooling_params, params=pooling_params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
@ -763,9 +765,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be " raise ValueError("Either prompts or prompt_token_ids must be "
"provided.") "provided.")
inputs: List[PromptInputs] = [] parsed_prompts: List[PromptType] = []
for i in range(num_requests): for i in range(num_requests):
item: PromptInputs item: PromptType
if prompts is not None: if prompts is not None:
item = TextPrompt(prompt=prompts[i]) item = TextPrompt(prompt=prompts[i])
@ -774,13 +776,13 @@ class LLM:
else: else:
raise AssertionError raise AssertionError
inputs.append(item) parsed_prompts.append(item)
return inputs return parsed_prompts
def _validate_and_add_requests( def _validate_and_add_requests(
self, self,
inputs: Union[PromptInputs, Sequence[PromptInputs]], prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]], Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
@ -788,11 +790,11 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None, guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None, priority: Optional[List[int]] = None,
) -> None: ) -> None:
if isinstance(inputs, (str, dict)): if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list. # Convert a single prompt to a list.
inputs = [inputs] prompts = [prompts]
num_requests = len(inputs) num_requests = len(prompts)
if isinstance(params, list) and len(params) != num_requests: if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params " raise ValueError("The lengths of prompts and params "
"must be the same.") "must be the same.")
@ -809,9 +811,9 @@ class LLM:
sp.output_kind = RequestOutputKind.FINAL_ONLY sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine. # Add requests to the engine.
for i, request_inputs in enumerate(inputs): for i, prompt in enumerate(prompts):
self._add_request( self._add_request(
request_inputs, prompt,
params[i] if isinstance(params, Sequence) else params, params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance( lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request, lora_request, Sequence) else lora_request,
@ -821,7 +823,7 @@ class LLM:
def _add_request( def _add_request(
self, self,
inputs: PromptInputs, prompt: PromptType,
params: Union[SamplingParams, PoolingParams], params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None, lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -830,7 +832,7 @@ class LLM:
request_id = str(next(self.request_counter)) request_id = str(next(self.request_counter))
self.llm_engine.add_request( self.llm_engine.add_request(
request_id, request_id,
inputs, prompt,
params, params,
lora_request=lora_request, lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request, prompt_adapter_request=prompt_adapter_request,
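
After this change, 'prompts' is again the documented keyword for LLM.generate and LLM.encode, while 'prompt_token_ids' remains on the legacy deprecation path (gated by LLM.DEPRECATE_LEGACY). A short sketch with a placeholder model:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.0)

# Supported: 'prompts' as a keyword (or positionally), with strings or dicts.
llm.generate(prompts=["Hello my name is", "The capital of France is"],
             sampling_params=params)
llm.generate([{"prompt": "Hello my name is"}], sampling_params=params)

# Legacy and deprecated: 'prompt_token_ids', converted internally by
# _convert_v1_inputs into token prompts.
llm.generate(prompt_token_ids=[[1, 2, 3]], sampling_params=params)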

View File

@ -1,5 +1,5 @@
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt, from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt, LLMInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt, TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts) to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
@ -16,8 +16,8 @@ See also:
__all__ = [ __all__ = [
"TextPrompt", "TextPrompt",
"TokensPrompt", "TokensPrompt",
"PromptInputs", "PromptType",
"SingletonPromptInputs", "SingletonPrompt",
"ExplicitEncoderDecoderPrompt", "ExplicitEncoderDecoderPrompt",
"LLMInputs", "LLMInputs",
"EncoderDecoderLLMInputs", "EncoderDecoderLLMInputs",
@ -28,3 +28,17 @@ __all__ = [
"InputContext", "InputContext",
"InputRegistry", "InputRegistry",
] ]
def __getattr__(name: str):
if name == "PromptInput":
import warnings
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return PromptType
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
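
The module-level __getattr__ above is the standard PEP 562 hook, which lets an old attribute name keep resolving (with a warning) even though it is no longer defined in the module. A generic sketch of the pattern; the module and class names here are illustrative, not vLLM's:

# mylib/shapes.py -- illustrative module
import warnings


class Rectangle:
    pass


def __getattr__(name: str):
    # Invoked only when normal module attribute lookup fails, so
    # 'from mylib.shapes import Rect' still works, with a warning.
    if name == "Rect":
        warnings.warn(
            DeprecationWarning("Rect has been renamed to Rectangle. "
                               "The original name will be removed soon."),
            stacklevel=2)
        return Rectangle
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")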

View File

@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
""" """
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
""" """
Set of possible schemas for a single LLM input: Set of possible schemas for a single LLM input:
@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt` prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPromptInputs` may be employed A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or where the decoder-prompt is not specified explicitly, or
@ -55,33 +55,33 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
""" """
_T1_co = TypeVar("_T1_co", _T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs, bound=SingletonPrompt,
default=SingletonPromptInputs, default=SingletonPrompt,
covariant=True) covariant=True)
_T2_co = TypeVar("_T2_co", _T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs, bound=SingletonPrompt,
default=SingletonPromptInputs, default=SingletonPrompt,
covariant=True) covariant=True)
# TODO: Make fields ReadOnly once mypy supports it # TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]): class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt, """
comprising an explicit encoder prompt and a Represents an encoder/decoder model input prompt,
decoder prompt. comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively, The encoder and decoder prompts, respectively,
may be formatted according to any of the may be formatted according to any of the
:class:`SingletonPromptInputs` schemas, and are not :class:`SingletonPrompt` schemas, and are not
required to have the same schema. required to have the same schema.
Only the encoder prompt may have multi-modal data. Only the encoder prompt may have multi-modal data.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model, be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt` and that the :code:`encoder_prompt` and :code:`decoder_prompt`
fields of this data structure themselves must be fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances. :class:`SingletonPrompt` instances.
""" """
encoder_prompt: _T1_co encoder_prompt: _T1_co
@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt: Optional[_T2_co] decoder_prompt: Optional[_T2_co]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt] PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
""" """
Set of possible schemas for an LLM input, including Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types: both decoder-only and encoder/decoder input types:
@@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
     """


-_T1 = TypeVar("_T1",
-              bound=SingletonPromptInputs,
-              default=SingletonPromptInputs)
-_T2 = TypeVar("_T2",
-              bound=SingletonPromptInputs,
-              default=SingletonPromptInputs)
+_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
+_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)


 def build_explicit_enc_dec_prompt(
@@ -176,3 +172,17 @@ def to_enc_dec_tuple_list(
     return [(enc_dec_prompt["encoder_prompt"],
              enc_dec_prompt["decoder_prompt"])
             for enc_dec_prompt in enc_dec_prompts]
+
+
+def __getattr__(name: str):
+    if name == "PromptInput":
+        import warnings
+
+        msg = ("PromptInput has been renamed to PromptType. "
+               "The original name will be removed in an upcoming version.")
+
+        warnings.warn(DeprecationWarning(msg), stacklevel=2)
+
+        return PromptType
+
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
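
The __getattr__ hook above is the backward-compatibility piece of this change: a module-level __getattr__ (PEP 562) lets the old attribute name still resolve, but with a DeprecationWarning. A sketch of the expected behaviour (editor's illustration; it assumes the module diffed here is importable as vllm.inputs.data):

# Editor's sketch (not part of this commit): exercising the deprecation shim.
import warnings

from vllm.inputs import data  # path assumed for the module diffed above

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = data.PromptInput  # old-style access, resolved via __getattr__

assert legacy is data.PromptType
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Unknown attributes still fail loudly.
try:
    data.NotAThing
except AttributeError:
    pass
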

View File

@@ -5,7 +5,7 @@ from typing_extensions import TypeIs

 from vllm.utils import is_list_of

 from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-                   LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
+                   LLMInputs, PromptType, SingletonPrompt, TextPrompt,
                    TokensPrompt)
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):


 def parse_singleton_prompt(
-    inputs: SingletonPromptInputs,
+    prompt: SingletonPrompt,
 ) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
-    if isinstance(inputs, str):
-        return ParsedStrPrompt(type="str", content=inputs)
-    elif isinstance(inputs, dict):
-        if "prompt_token_ids" in inputs:
+    if isinstance(prompt, str):
+        return ParsedStrPrompt(type="str", content=prompt)
+    elif isinstance(prompt, dict):
+        if "prompt_token_ids" in prompt:
             return ParsedTokensPrompt(type="tokens",
-                                      content=inputs)  # type: ignore
-        elif "prompt" in inputs:
-            return ParsedTextPrompt(type="text", content=inputs)
+                                      content=prompt)  # type: ignore
+        elif "prompt" in prompt:
+            return ParsedTextPrompt(type="text", content=prompt)

     raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")


 def is_explicit_encoder_decoder_prompt(
-        inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
-    return isinstance(inputs, dict) and "encoder_prompt" in inputs
+        prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
+    return isinstance(prompt, dict) and "encoder_prompt" in prompt


 def is_valid_encoder_decoder_llm_inputs(
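
The renamed helpers keep their behaviour: parse_singleton_prompt classifies a singleton prompt by its shape, and is_explicit_encoder_decoder_prompt only accepts dicts carrying an encoder_prompt key. A quick sketch (editor's illustration; it assumes the module diffed here is importable as vllm.inputs.parse):

# Editor's sketch (not part of this commit): behaviour of the renamed helpers.
from vllm.inputs.parse import (is_explicit_encoder_decoder_prompt,
                               parse_singleton_prompt)

assert parse_singleton_prompt("Hello!")["type"] == "str"
assert parse_singleton_prompt({"prompt": "Hello!"})["type"] == "text"
assert parse_singleton_prompt({"prompt_token_ids": [1, 2, 3]})["type"] == "tokens"

# Only dicts carrying an "encoder_prompt" key count as explicit
# encoder/decoder prompts; plain singleton prompts do not.
assert is_explicit_encoder_decoder_prompt(
    {"encoder_prompt": "enc", "decoder_prompt": "dec"})
assert not is_explicit_encoder_decoder_prompt("just a decoder prompt")
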

View File

@@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

-from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
-                   SingletonPromptInputs)
+from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
+                   SingletonPrompt)
 from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt

 if TYPE_CHECKING:
@@ -206,7 +206,7 @@ class InputPreprocessor:

     def _extract_prompt_components(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
     ) -> PromptComponents:
@@ -216,7 +216,7 @@ class InputPreprocessor:
         Arguments:

         * request_id
-        * inputs: single encoder or decoder input prompt
+        * prompt: single encoder or decoder input prompt
         * lora_request: this is only valid for decoder prompts

         Returns:
@@ -226,24 +226,24 @@ class InputPreprocessor:
         * multi_modal_data
         '''

-        parsed = parse_singleton_prompt(inputs)
+        parsed = parse_singleton_prompt(prompt)

         if parsed["type"] == "str":
-            prompt = parsed["content"]
+            prompt_text = parsed["content"]
             prompt_token_ids = self._tokenize_prompt(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
             multi_modal_data = None
         elif parsed["type"] == "tokens":
-            prompt = None
+            prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
         elif parsed["type"] == "text":
-            prompt = parsed["content"]["prompt"]
+            prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = self._tokenize_prompt(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
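
Stripped of the tokenizer plumbing, the branching above maps a parsed singleton prompt onto a (prompt_text, prompt_token_ids, multi_modal_data) triple. An editor's sketch of that mapping (not part of this diff; tokenize is a hypothetical stand-in for InputPreprocessor._tokenize_prompt, and the assert_never fallback is omitted):

# Editor's sketch: the parsed-prompt -> components mapping performed above.
from typing import Any, Callable, List, Optional, Tuple

from vllm.inputs.parse import parse_singleton_prompt  # path assumed

PromptComponents = Tuple[Optional[str], List[int], Optional[Any]]


def extract_components(prompt,
                       tokenize: Callable[[str], List[int]]) -> PromptComponents:
    """Mirror of the branching above, minus request_id/LoRA plumbing."""
    parsed = parse_singleton_prompt(prompt)
    if parsed["type"] == "str":
        # Plain string: tokenize it; no multi-modal data possible.
        return parsed["content"], tokenize(parsed["content"]), None
    if parsed["type"] == "tokens":
        # Pre-tokenized prompt: no text form is available.
        content = parsed["content"]
        return None, content["prompt_token_ids"], content.get("multi_modal_data")
    # "text": a TextPrompt dict that still needs tokenization.
    content = parsed["content"]
    return (content["prompt"], tokenize(content["prompt"]),
            content.get("multi_modal_data"))
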
@@ -251,33 +251,33 @@ class InputPreprocessor:
         else:
             assert_never(parsed)

-        return prompt, prompt_token_ids, multi_modal_data
+        return prompt_text, prompt_token_ids, multi_modal_data

     async def _extract_prompt_components_async(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
     ) -> PromptComponents:
         """Async version of :meth:`_extract_prompt_components`."""
-        parsed = parse_singleton_prompt(inputs)
+        parsed = parse_singleton_prompt(prompt)

         if parsed["type"] == "str":
-            prompt = parsed["content"]
+            prompt_text = parsed["content"]
             prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
             multi_modal_data = None
         elif parsed["type"] == "tokens":
-            prompt = None
+            prompt_text = None
             prompt_token_ids = parsed["content"]["prompt_token_ids"]
             multi_modal_data = parsed["content"].get("multi_modal_data")
         elif parsed["type"] == "text":
-            prompt = parsed["content"]["prompt"]
+            prompt_text = parsed["content"]["prompt"]
             prompt_token_ids = await self._tokenize_prompt_async(
-                prompt,
+                prompt_text,
                 request_id=request_id,
                 lora_request=lora_request,
             )
@@ -285,7 +285,7 @@ class InputPreprocessor:
         else:
             assert_never(parsed)

-        return prompt, prompt_token_ids, multi_modal_data
+        return prompt_text, prompt_token_ids, multi_modal_data

     def _build_enc_dec_llm_inputs(
         self,
@@ -311,7 +311,7 @@ class InputPreprocessor:

     def _process_encoder_decoder_prompt(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
     ) -> EncoderDecoderLLMInputs:
         '''
@@ -339,7 +339,7 @@ class InputPreprocessor:

         Arguments:

-        * inputs: an input prompt
+        * prompt: an input prompt
         * request_id

         Returns:
@@ -350,13 +350,13 @@ class InputPreprocessor:
         encoder_comps: PromptComponents
         decoder_comps: DecoderPromptComponents

-        if is_explicit_encoder_decoder_prompt(inputs):
+        if is_explicit_encoder_decoder_prompt(prompt):
             encoder_comps = self._extract_prompt_components(
-                inputs["encoder_prompt"],
+                prompt["encoder_prompt"],
                 request_id=request_id,
             )

-            if (decoder_input := inputs["decoder_prompt"]) is None:
+            if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_comps = None, None, None
             else:
                 decoder_comps = self._extract_prompt_components(
@@ -365,7 +365,7 @@ class InputPreprocessor:
             )
         else:
             encoder_comps = self._extract_prompt_components(
-                inputs,
+                prompt,
                 request_id=request_id,
             )
@@ -375,20 +375,20 @@ class InputPreprocessor:

     async def _process_encoder_decoder_prompt_async(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
     ) -> EncoderDecoderLLMInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_comps: PromptComponents
         decoder_comps: DecoderPromptComponents

-        if is_explicit_encoder_decoder_prompt(inputs):
+        if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._extract_prompt_components_async(
-                inputs["encoder_prompt"],
+                prompt["encoder_prompt"],
                 request_id=request_id,
             )

-            if (decoder_input := inputs["decoder_prompt"]) is None:
+            if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_comps = await encoder_task
                 decoder_comps = None, None, None
             else:
@@ -401,7 +401,7 @@ class InputPreprocessor:
                 encoder_task, decoder_task)
         else:
             encoder_comps = await self._extract_prompt_components_async(
-                inputs,
+                prompt,
                 request_id=request_id,
             )
@@ -425,7 +425,7 @@ class InputPreprocessor:

     def _process_decoder_only_prompt(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -436,7 +436,7 @@ class InputPreprocessor:

         Arguments:

-        * inputs: input prompt
+        * prompt: input prompt
         * request_id
         * lora_request
         * prompt_adapter_request
@@ -447,7 +447,7 @@ class InputPreprocessor:
         '''

         prompt_comps = self._extract_prompt_components(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
         )
@@ -459,14 +459,14 @@ class InputPreprocessor:

     async def _process_decoder_only_prompt_async(
         self,
-        inputs: SingletonPromptInputs,
+        prompt: SingletonPrompt,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._extract_prompt_components_async(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
         )
@@ -478,7 +478,7 @@ class InputPreprocessor:

     def preprocess(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -488,17 +488,17 @@ class InputPreprocessor:
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return self._process_encoder_decoder_prompt(
-                inputs,
+                prompt,
                 request_id=request_id,
             )

-        if is_explicit_encoder_decoder_prompt(inputs):
+        if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
                              "to decoder-only models")

         # Decoder-only operation
         return self._process_decoder_only_prompt(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
@@ -506,7 +506,7 @@ class InputPreprocessor:

     async def preprocess_async(
         self,
-        inputs: PromptInputs,
+        prompt: PromptType,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -516,17 +516,17 @@ class InputPreprocessor:
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
             return await self._process_encoder_decoder_prompt_async(
-                inputs,
+                prompt,
                 request_id=request_id,
             )

-        if is_explicit_encoder_decoder_prompt(inputs):
+        if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
                              "to decoder-only models")

         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
-            inputs,
+            prompt,
             request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
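
Beyond the parameter rename, the routing in preprocess()/preprocess_async() is unchanged: encoder/decoder models take the encoder/decoder path, while explicit encoder/decoder prompts are rejected for decoder-only models. A condensed sketch of the synchronous path (editor's illustration, not part of this diff; is_encoder_decoder_model stands in for the model check that sits just above the hunks shown here):

# Editor's sketch of the routing above; `preprocessor` is an InputPreprocessor
# instance and `is_encoder_decoder_model` is a stand-in flag.
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt  # path assumed


def route_prompt(preprocessor, prompt, request_id, *,
                 is_encoder_decoder_model: bool,
                 lora_request=None, prompt_adapter_request=None):
    if is_encoder_decoder_model:
        # Encoder-decoder model requires special mapping of
        # input prompts to encoder & decoder.
        return preprocessor._process_encoder_decoder_prompt(
            prompt, request_id=request_id)

    if is_explicit_encoder_decoder_prompt(prompt):
        raise ValueError("Cannot pass encoder-decoder prompt "
                         "to decoder-only models")

    # Decoder-only operation.
    return preprocessor._process_decoder_only_prompt(
        prompt,
        request_id=request_id,
        lora_request=lora_request,
        prompt_adapter_request=prompt_adapter_request,
    )
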