Revert "[Core] Rename PromptInputs to PromptType, and inputs to prompt" (#8750)

Simon Mo 2024-09-23 22:45:20 -07:00 committed by GitHub
parent 0250dd68c5
commit 3185fb0cca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 162 additions and 157 deletions

View File

@@ -11,7 +11,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
-dummy_prompts: List[PromptType] = [{
+dummy_inputs: List[PromptInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
-llm.generate(dummy_prompts,
+llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
-llm.generate(dummy_prompts,
+llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()

View File

@@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
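As a schema-level illustration of the ``multi_modal_data`` field described above, here is a hedged sketch (the image path, prompt text, and chat format are placeholders, not part of this change):

# Sketch: a TextPrompt-style dict that satisfies vllm.inputs.PromptInputs and
# carries image data through the multi_modal_data field.
from PIL import Image
from vllm.inputs import TextPrompt

image = Image.open("example.jpg")  # placeholder image path

prompt: TextPrompt = {
    "prompt": "USER: <image>\nDescribe this picture.\nASSISTANT:",  # model-specific format
    "multi_modal_data": {"image": image},  # schema: vllm.multimodal.MultiModalDataDict
}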

View File

@@ -1,7 +1,7 @@
LLM Inputs
==========
-.. autodata:: vllm.inputs.PromptType
+.. autodata:: vllm.inputs.PromptInputs
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:

View File

@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
-To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
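A hedged end-to-end sketch of the two fields listed above; the model name, chat format, and image file are illustrative assumptions rather than part of this change:

# Sketch: passing an image to a vision-language model via PromptInputs.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # placeholder VLM; any supported model works
image = Image.open("cherry_blossom.jpg")     # placeholder image

outputs = llm.generate(
    {
        # `prompt` follows the format documented on HuggingFace for the chosen model.
        "prompt": "USER: <image>\nWhat is the content of this image?\nASSISTANT:",
        # `multi_modal_data` follows vllm.multimodal.MultiModalDataDict.
        "multi_modal_data": {"image": image},
    },
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)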

View File

@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
-async for _ in client.generate(prompt="Hello my name is",
+async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(prompt="Hello my name is",
+async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(prompt="Hello my name is",
+async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -165,7 +165,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
-prompt="Hello my name is",
+inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
@@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
-async for _ in client.generate(prompt="Hello my name is",
+async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
@@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
pass
# This request should be okay.
-async for _ in client.generate(prompt="Hello my name is",
+async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass

View File

@@ -20,7 +20,7 @@ async def generate(
count = 0
async for out in client.generate(
request_id=request_id,
-prompt="Hello my name is Robert and",
+inputs="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):

View File

@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__",
"LLM",
"ModelRegistry",
-"PromptType",
+"PromptInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",

View File

@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@@ -405,7 +405,7 @@ class _AsyncLLMEngine(LLMEngine):
async def add_request_async(
self,
request_id: str,
-prompt: PromptType,
+inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -420,7 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
-prompt,
+inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -777,7 +777,7 @@ class AsyncLLMEngine:
async def add_request(
self,
request_id: str,
-prompt: PromptType,
+inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -797,7 +797,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
-prompt=prompt,
+inputs=inputs,
params=params,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
@@ -808,7 +808,7 @@ class AsyncLLMEngine:
async def generate(
self,
-prompt: PromptType,
+inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -822,7 +822,8 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
-prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+inputs: The inputs to the LLM. See
+:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
@@ -880,7 +881,7 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
-prompt,
+inputs,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
@@ -890,7 +891,7 @@ class AsyncLLMEngine:
async def encode(
self,
-prompt: PromptType,
+inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -903,7 +904,8 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
-prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+inputs: The inputs to the LLM. See
+:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
@@ -957,7 +959,7 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
-prompt,
+inputs,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,
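For context on the reverted keyword, a minimal caller sketch of ``AsyncLLMEngine.generate`` (the engine arguments and model name are assumptions, not part of this commit):

# Sketch: driving AsyncLLMEngine.generate with the restored `inputs=` keyword.
import asyncio

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model
    final_output = None
    async for request_output in engine.generate(
            inputs="Hello my name is",  # was `prompt=` before this revert
            sampling_params=SamplingParams(max_tokens=16),
            request_id="request-0"):
        final_output = request_output  # streamed; keep the last update
    if final_output is not None:
        print(final_output.outputs[0].text)


asyncio.run(main())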

View File

@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-InputRegistry, LLMInputs, PromptType)
+InputRegistry, LLMInputs, PromptInputs)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -689,7 +689,7 @@ class LLMEngine:
def add_request(
self,
request_id: str,
-prompt: PromptType,
+inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@@ -704,7 +704,8 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
-prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+inputs: The inputs to the LLM. See
+:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
@@ -744,7 +745,7 @@ class LLMEngine:
arrival_time = time.time()
preprocessed_inputs = self.input_preprocessor.preprocess(
-prompt,
+inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,

View File

@@ -3,7 +3,7 @@ from enum import Enum
from typing import List, Mapping, Optional, Union
from vllm import PoolingParams
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):
@dataclass
class RPCProcessRequest:
-prompt: PromptType
+inputs: PromptInputs
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None

View File

@@ -25,7 +25,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptType
+from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@@ -375,7 +375,7 @@ class MQLLMEngineClient:
def generate(
self,
-prompt: PromptType,
+inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -389,7 +389,8 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
-prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+inputs: The inputs to the LLM. See
+:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
@@ -398,13 +399,13 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
-return self._process_request(prompt, sampling_params, request_id,
+return self._process_request(inputs, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
def encode(
self,
-prompt: PromptType,
+inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -417,7 +418,8 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
-prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
+inputs: The inputs to the LLM. See
+:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
@@ -428,12 +430,12 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
-return self._process_request(prompt, pooling_params, request_id,
+return self._process_request(inputs, pooling_params, request_id,
lora_request, trace_headers)
async def _process_request(
self,
-prompt: PromptType,
+inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
@@ -466,7 +468,7 @@ class MQLLMEngineClient:
request_bytes = pickle.dumps(
RPCProcessRequest(
-prompt=prompt,
+inputs=inputs,
params=params,
request_id=request_id,
lora_request=lora_request,

View File

@@ -252,7 +252,7 @@ class MQLLMEngine:
try:
self.engine.add_request(
request_id=request_id,
-prompt=request.prompt,
+inputs=request.inputs,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,

View File

@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@@ -35,19 +35,19 @@ class EngineClient(Protocol):
def generate(
self,
-prompt: PromptType,
+inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
-"""Generate outputs for a request."""
+"""Generates outputs for a request"""
...
def encode(
self,
-prompt: PromptType,
+inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,

View File

@@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -293,8 +293,8 @@ class LLM:
@overload
def generate(
self,
-prompts: Union[PromptType, Sequence[PromptType]],
-/,
+inputs: Union[PromptInputs, Sequence[PromptInputs]],
+/,  # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@@ -311,7 +311,7 @@ class LLM:
)
def generate(
self,
-prompts: Union[Union[PromptType, Sequence[PromptType]],
+prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@@ -329,9 +329,7 @@ class LLM:
into a single list and pass it to this method.
Args:
-prompts: The prompts to the LLM. You may pass a sequence of prompts
-for batch inference. See :class:`~vllm.inputs.PromptType`
-for more details about the format of each prompts.
+inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
@@ -357,13 +355,12 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
-parsed_prompts = self._convert_v1_inputs(
+inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
-parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
-prompts)
+inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
@@ -378,7 +375,7 @@ class LLM:
sampling_params = SamplingParams()
self._validate_and_add_requests(
-prompts=parsed_prompts,
+inputs=inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -533,9 +530,9 @@ class LLM:
conversation, mm_data = parse_chat_messages(messages, model_config,
tokenizer)
-prompt_data: Union[str, List[int]]
+prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
-prompt_data = apply_mistral_chat_template(
+prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
chat_template=chat_template,
@@ -543,7 +540,7 @@ class LLM:
tools=tools,
)
else:
-prompt_data = apply_hf_chat_template(
+prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
@@ -551,17 +548,17 @@ class LLM:
tools=tools,
)
-prompt: PromptType
-if is_list_of(prompt_data, int):
-prompt = TokensPrompt(prompt_token_ids=prompt_data)
+inputs: PromptInputs
+if is_list_of(prompt, int):
+inputs = TokensPrompt(prompt_token_ids=prompt)
else:
-prompt = TextPrompt(prompt=prompt_data)
+inputs = TextPrompt(prompt=prompt)
if mm_data is not None:
-prompt["multi_modal_data"] = mm_data
+inputs["multi_modal_data"] = mm_data
return self.generate(
-prompt,
+inputs,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
@@ -631,8 +628,8 @@ class LLM:
@overload
def encode(
self,
-prompts: Union[PromptType, Sequence[PromptType]],
-/,
+inputs: Union[PromptInputs, Sequence[PromptInputs]],
+/,  # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@@ -649,7 +646,7 @@ class LLM:
)
def encode(
self,
-prompts: Union[Union[PromptType, Sequence[PromptType]],
+prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@@ -665,9 +662,9 @@ class LLM:
into a single list and pass it to this method.
Args:
-prompts: The prompts to the LLM. You may pass a sequence of prompts
-for batch inference. See :class:`~vllm.inputs.PromptType`
-for more details about the format of each prompts.
+inputs: The inputs to the LLM. You may pass a sequence of inputs for
+batch inference. See :class:`~vllm.inputs.PromptInputs`
+for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
@@ -690,20 +687,19 @@ class LLM:
)
if prompt_token_ids is not None:
-parsed_prompts = self._convert_v1_inputs(
+inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
-parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
-prompts)
+inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
-prompts=parsed_prompts,
+inputs=inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -747,9 +743,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
-parsed_prompts: List[PromptType] = []
+inputs: List[PromptInputs] = []
for i in range(num_requests):
-item: PromptType
+item: PromptInputs
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
@@ -758,24 +754,24 @@ class LLM:
else:
raise AssertionError
-parsed_prompts.append(item)
+inputs.append(item)
-return parsed_prompts
+return inputs
def _validate_and_add_requests(
self,
-prompts: Union[PromptType, Sequence[PromptType]],
+inputs: Union[PromptInputs, Sequence[PromptInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
) -> None:
-if isinstance(prompts, (str, dict)):
+if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
-prompts = [prompts]
+inputs = [inputs]
-num_requests = len(prompts)
+num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
@@ -792,9 +788,9 @@ class LLM:
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
-for i, prompt in enumerate(prompts):
+for i, request_inputs in enumerate(inputs):
self._add_request(
-prompt,
+request_inputs,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
@@ -803,7 +799,7 @@ class LLM:
def _add_request(
self,
-prompt: PromptType,
+inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -811,7 +807,7 @@ class LLM:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
-prompt,
+inputs,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
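A short usage sketch of the ``LLM.generate`` path touched above, under the reverted signature (the model name and token IDs are placeholders):

# Sketch: strings, TextPrompt dicts, and TokensPrompt dicts are all valid
# PromptInputs, so a mixed batch can be passed to LLM.generate.
from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt, TokensPrompt

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(temperature=0.8, max_tokens=32)

outputs = llm.generate(
    [
        "Hello, my name is",
        TextPrompt(prompt="The capital of France is"),
        TokensPrompt(prompt_token_ids=[1, 2, 3]),  # placeholder token IDs
    ],
    params,
)
for output in outputs:
    print(output.outputs[0].text)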

View File

@@ -1,5 +1,5 @@
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-LLMInputs, PromptType, SingletonPrompt, TextPrompt,
+LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
@@ -16,8 +16,8 @@ See also:
__all__ = [
"TextPrompt",
"TokensPrompt",
-"PromptType",
-"SingletonPrompt",
+"PromptInputs",
+"SingletonPromptInputs",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",

View File

@@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
"""
-SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
+SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single LLM input:
@@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
-A prompt of type :class:`SingletonPromptType` may be employed
+A prompt of type :class:`SingletonPromptInputs` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
@@ -55,12 +55,12 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co = TypeVar("_T1_co",
-bound=SingletonPrompt,
-default=SingletonPrompt,
+bound=SingletonPromptInputs,
+default=SingletonPromptInputs,
covariant=True)
_T2_co = TypeVar("_T2_co",
-bound=SingletonPrompt,
-default=SingletonPrompt,
+bound=SingletonPromptInputs,
+default=SingletonPromptInputs,
covariant=True)
@@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
The encoder and decoder prompts, respectively,
may formatted according to any of the
-:class:`SingletonPromptType` schemas, and are not
+:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
@@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
-:class:`SingletonPromptType` instances.
+:class:`SingletonPromptInputs` instances.
"""
encoder_prompt: _T1_co
@@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt: Optional[_T2_co]
-PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
+PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
@@ -140,8 +140,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
-_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
-_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
+_T1 = TypeVar("_T1",
+bound=SingletonPromptInputs,
+default=SingletonPromptInputs)
+_T2 = TypeVar("_T2",
+bound=SingletonPromptInputs,
+default=SingletonPromptInputs)
def build_explicit_enc_dec_prompt(
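To make the restored names concrete, a small sketch of the schemas defined in this file (assuming ``build_explicit_enc_dec_prompt`` simply pairs the two singleton prompts; the prompt contents are placeholders):

# Sketch: SingletonPromptInputs values (str / TextPrompt / TokensPrompt) and an
# ExplicitEncoderDecoderPrompt combining one of each.
from vllm.inputs import (TextPrompt, TokensPrompt,
                         build_explicit_enc_dec_prompt)

encoder_prompt = TextPrompt(prompt="An encoder prompt")    # a SingletonPromptInputs
decoder_prompt = TokensPrompt(prompt_token_ids=[0, 1, 2])  # placeholder token IDs

enc_dec_prompt = build_explicit_enc_dec_prompt(
    encoder_prompt=encoder_prompt,
    decoder_prompt=decoder_prompt,
)
# enc_dec_prompt is a PromptInputs value with "encoder_prompt" and
# "decoder_prompt" keys, suitable for encoder/decoder models.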

View File

@@ -5,7 +5,7 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
-LLMInputs, PromptType, SingletonPrompt, TextPrompt,
+LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt)
@@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
def parse_singleton_prompt(
-prompt: SingletonPrompt,
+inputs: SingletonPromptInputs,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
-if isinstance(prompt, str):
-return ParsedStrPrompt(type="str", content=prompt)
-elif isinstance(prompt, dict):
-if "prompt_token_ids" in prompt:
+if isinstance(inputs, str):
+return ParsedStrPrompt(type="str", content=inputs)
+elif isinstance(inputs, dict):
+if "prompt_token_ids" in inputs:
return ParsedTokensPrompt(type="tokens",
-content=prompt) # type: ignore
-elif "prompt" in prompt:
-return ParsedTextPrompt(type="text", content=prompt)
+content=inputs) # type: ignore
+elif "prompt" in inputs:
+return ParsedTextPrompt(type="text", content=inputs)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_explicit_encoder_decoder_prompt(
-prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
-return isinstance(prompt, dict) and "encoder_prompt" in prompt
+inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
+return isinstance(inputs, dict) and "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
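The behaviour of the two helpers changed above, illustrated as a small sketch grounded in the code in this hunk (the prompt strings and token IDs are placeholders):

# Sketch: how parse_singleton_prompt and is_explicit_encoder_decoder_prompt
# classify the different PromptInputs shapes.
from vllm.inputs.parse import (is_explicit_encoder_decoder_prompt,
                               parse_singleton_prompt)

assert parse_singleton_prompt("Hello")["type"] == "str"
assert parse_singleton_prompt({"prompt": "Hello"})["type"] == "text"
assert parse_singleton_prompt({"prompt_token_ids": [1, 2]})["type"] == "tokens"

assert not is_explicit_encoder_decoder_prompt("Hello")
assert is_explicit_encoder_decoder_prompt(
    {"encoder_prompt": "Hello", "decoder_prompt": None})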

View File

@@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
-from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
-SingletonPrompt)
+from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
+SingletonPromptInputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
@@ -206,7 +206,7 @@ class InputPreprocessor:
def _extract_prompt_components(
self,
-prompt: SingletonPrompt,
+inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
@@ -216,7 +216,7 @@ class InputPreprocessor:
Arguments:
* request_id
-* prompt: single encoder or decoder input prompt
+* inputs: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
@@ -226,24 +226,24 @@ class InputPreprocessor:
* multi_modal_data
'''
-parsed = parse_singleton_prompt(prompt)
+parsed = parse_singleton_prompt(inputs)
if parsed["type"] == "str":
-prompt_text = parsed["content"]
+prompt = parsed["content"]
prompt_token_ids = self._tokenize_prompt(
-prompt_text,
+prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
-prompt_text = None
+prompt = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
-prompt_text = parsed["content"]["prompt"]
+prompt = parsed["content"]["prompt"]
prompt_token_ids = self._tokenize_prompt(
-prompt_text,
+prompt,
request_id=request_id,
lora_request=lora_request,
)
@@ -251,33 +251,33 @@ class InputPreprocessor:
else:
assert_never(parsed)
-return prompt_text, prompt_token_ids, multi_modal_data
+return prompt, prompt_token_ids, multi_modal_data
async def _extract_prompt_components_async(
self,
-prompt: SingletonPrompt,
+inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
-parsed = parse_singleton_prompt(prompt)
+parsed = parse_singleton_prompt(inputs)
if parsed["type"] == "str":
-prompt_text = parsed["content"]
+prompt = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async(
-prompt_text,
+prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
-prompt_text = None
+prompt = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
-prompt_text = parsed["content"]["prompt"]
+prompt = parsed["content"]["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
-prompt_text,
+prompt,
request_id=request_id,
lora_request=lora_request,
)
@@ -285,7 +285,7 @@ class InputPreprocessor:
else:
assert_never(parsed)
-return prompt_text, prompt_token_ids, multi_modal_data
+return prompt, prompt_token_ids, multi_modal_data
def _build_enc_dec_llm_inputs(
self,
@@ -311,7 +311,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
-prompt: PromptType,
+inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
@@ -339,7 +339,7 @@ class InputPreprocessor:
Arguments:
-* prompt: an input prompt
+* inputs: an input prompt
* request_id
Returns:
@@ -350,13 +350,13 @@ class InputPreprocessor:
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
-if is_explicit_encoder_decoder_prompt(prompt):
+if is_explicit_encoder_decoder_prompt(inputs):
encoder_comps = self._extract_prompt_components(
-prompt["encoder_prompt"],
+inputs["encoder_prompt"],
request_id=request_id,
)
-if (decoder_input := prompt["decoder_prompt"]) is None:
+if (decoder_input := inputs["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
@@ -365,7 +365,7 @@ class InputPreprocessor:
)
else:
encoder_comps = self._extract_prompt_components(
-prompt,
+inputs,
request_id=request_id,
)
@@ -375,20 +375,20 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async(
self,
-prompt: PromptType,
+inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
-if is_explicit_encoder_decoder_prompt(prompt):
+if is_explicit_encoder_decoder_prompt(inputs):
encoder_task = self._extract_prompt_components_async(
-prompt["encoder_prompt"],
+inputs["encoder_prompt"],
request_id=request_id,
)
-if (decoder_input := prompt["decoder_prompt"]) is None:
+if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
@@ -401,7 +401,7 @@ class InputPreprocessor:
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
-prompt,
+inputs,
request_id=request_id,
)
@@ -425,7 +425,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
-prompt: SingletonPrompt,
+inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -436,7 +436,7 @@ class InputPreprocessor:
Arguments:
-* prompt: input prompt
+* inputs: input prompt
* request_id
* lora_request
* prompt_adapter_request
@@ -447,7 +447,7 @@ class InputPreprocessor:
'''
prompt_comps = self._extract_prompt_components(
-prompt,
+inputs,
request_id=request_id,
lora_request=lora_request,
)
@@ -459,14 +459,14 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async(
self,
-prompt: SingletonPrompt,
+inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
-prompt,
+inputs,
request_id=request_id,
lora_request=lora_request,
)
@@ -478,7 +478,7 @@ class InputPreprocessor:
def preprocess(
self,
-prompt: PromptType,
+inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -488,17 +488,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
-prompt,
+inputs,
request_id=request_id,
)
-if is_explicit_encoder_decoder_prompt(prompt):
+if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return self._process_decoder_only_prompt(
-prompt,
+inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@@ -506,7 +506,7 @@ class InputPreprocessor:
async def preprocess_async(
self,
-prompt: PromptType,
+inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@@ -516,17 +516,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
-prompt,
+inputs,
request_id=request_id,
)
-if is_explicit_encoder_decoder_prompt(prompt):
+if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
-prompt,
+inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,