[Core] rename `PromptInputs` and `inputs` (#8876)

Cyrus Leung authored on 2024-09-27 11:35:15 +08:00; committed by GitHub
parent 344cd2b6f4
commit 3b00b9c26c
21 changed files with 440 additions and 248 deletions

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptInputs] = [{
dummy_prompts: List[PromptType] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(dummy_inputs,
llm.generate(dummy_prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()

View File

@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.

View File

@ -1,7 +1,7 @@
LLM Inputs
==========
.. autodata:: vllm.inputs.PromptInputs
.. autodata:: vllm.inputs.PromptType
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:
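
For reference, the renamed type covers the same prompt schemas as before. A minimal sketch of the accepted singleton forms, assuming this version of vllm is installed (the token IDs are arbitrary illustrative values):

from vllm.inputs import PromptType, TextPrompt, TokensPrompt

# A plain string, a TextPrompt dict, and a TokensPrompt dict are all valid
# PromptType values for a decoder-only model.
as_string: PromptType = "Hello, my name is"
as_text: PromptType = TextPrompt(prompt="Hello, my name is")
as_tokens: PromptType = TokensPrompt(prompt_token_ids=[101, 102, 103])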

View File

@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
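
As a rough illustration of these two fields, a short sketch of single-image inference; the model name, prompt template, and image path are assumptions for the example rather than part of this change:

from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # example VLM, chosen arbitrarily
image = Image.open("example.jpg")            # any RGB image

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)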

View File

@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
@pytest.mark.asyncio
async def test_new_requests_event():
params = SamplingParams()
engine = MockAsyncLLMEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
await engine.add_request("1", "", None)
await engine.add_request("1", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 1
assert engine.engine.step_calls == 1
await engine.add_request("2", "", None)
await engine.add_request("2", "", params)
engine.engine.generate("2")
await asyncio.sleep(0)
await asyncio.sleep(0)
@ -111,7 +113,7 @@ async def test_new_requests_event():
await asyncio.sleep(0.001)
assert engine.engine.step_calls == old_step_calls
await engine.add_request("3", "", None)
await engine.add_request("3", "", params)
await asyncio.sleep(0.01)
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1

View File

@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
v2_output = llm.encode(prompt, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
pooling_params = PoolingParams()
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.encode(
[{
"prompt": p
} for p in PROMPTS],
pooling_params=pooling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
pooling_params = PoolingParams()

View File

@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
assert [o.outputs for o in o1] == [o.outputs for o in o2]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt', PROMPTS)
def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=prompt,
sampling_params=sampling_params)
v2_output = llm.generate(prompt, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate({"prompt": prompt},
sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
with pytest.warns(DeprecationWarning, match="'prompts'"):
v1_output = llm.generate(prompts=PROMPTS,
sampling_params=sampling_params)
v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
assert_outputs_equal(v1_output, v2_output)
v2_output = llm.generate(
[{
"prompt": p
} for p in PROMPTS],
sampling_params=sampling_params,
)
assert_outputs_equal(v1_output, v2_output)
@pytest.mark.skip_global_cleanup
def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
sampling_params = SamplingParams(temperature=0.0, top_p=1.0)

View File

@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
inputs="Hello my name is",
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
pass
# This request should be okay.
async for _ in client.generate(inputs="Hello my name is",
async for _ in client.generate(prompt="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass

View File

@ -20,7 +20,7 @@ async def generate(
count = 0
async for out in client.generate(
request_id=request_id,
inputs="Hello my name is Robert and",
prompt="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):

View File

@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptInputs",
"PromptType",
"TextPrompt",
"TokensPrompt",
"SamplingParams",

View File

@ -2,8 +2,8 @@ import asyncio
import time
import weakref
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
List, Mapping, Optional, Set, Tuple, Type, Union, overload)
from weakref import ReferenceType
import vllm.envs as envs
@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind
from vllm.utils import deprecate_kwargs, weak_bind
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@ -402,17 +402,54 @@ class _AsyncLLMEngine(LLMEngine):
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
@overload # DEPRECATED
async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@overload
async def add_request_async(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request_async(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Async version of :meth:`add_request`."""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
@ -420,7 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
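
The pattern above keeps the old keyword working: typed @overload stubs describe both call styles, and the decorator emits a warning whenever the deprecated keyword is actually used, before the body remaps it onto the new parameter. The real decorator lives in vllm.utils; the following is only a rough sketch of the idea (it omits the is_deprecated hook used elsewhere in this commit):

import functools
import warnings
from typing import Any, Callable, TypeVar

_F = TypeVar("_F", bound=Callable[..., Any])


def deprecate_kwargs(*deprecated: str,
                     additional_message: str = "") -> Callable[[_F], _F]:
    """Warn when any of the named keyword arguments is passed."""

    def decorator(fn: _F) -> _F:

        @functools.wraps(fn)
        def inner(*args: Any, **kwargs: Any) -> Any:
            used = [name for name in deprecated if name in kwargs]
            if used:
                warnings.warn(
                    f"The keyword arguments {used} are deprecated and will be "
                    f"removed in a future release. {additional_message}",
                    DeprecationWarning,
                    stacklevel=2,
                )
            return fn(*args, **kwargs)

        return inner  # type: ignore[return-value]

    return decorator
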
@ -774,16 +811,55 @@ class AsyncLLMEngine:
# This method does not need to be async, but kept that way
# for backwards compatibility.
async def add_request(
@overload # DEPRECATED
def add_request(
self,
request_id: str,
inputs: PromptInputs,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@overload
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> Coroutine[None, None, AsyncGenerator[Union[
RequestOutput, EmbeddingRequestOutput], None]]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
async def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
@ -797,7 +873,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
prompt=prompt,
params=params,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
@ -808,7 +884,7 @@ class AsyncLLMEngine:
async def generate(
self,
inputs: PromptInputs,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -822,8 +898,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
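
A minimal usage sketch of the renamed parameter on the async engine; the model name is an assumption and engine construction is kept simple:

import asyncio
import uuid

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))  # example model
    async for output in engine.generate(
            prompt="Hello my name is",
            sampling_params=SamplingParams(max_tokens=16),
            request_id=str(uuid.uuid4())):
        print(output.outputs[0].text)


asyncio.run(main())
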
@ -881,7 +956,7 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
inputs,
prompt,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
@ -891,7 +966,7 @@ class AsyncLLMEngine:
async def encode(
self,
inputs: PromptInputs,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -904,8 +979,7 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
@ -959,7 +1033,7 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
inputs,
prompt,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,

View File

@ -6,7 +6,7 @@ from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union
from typing import Set, Type, Union, overload
import torch
from typing_extensions import TypeVar
@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs)
InputRegistry, LLMInputs, PromptType)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, weak_bind
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -689,16 +689,51 @@ class LLMEngine:
def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()
@overload # DEPRECATED
def add_request(
self,
request_id: str,
inputs: PromptInputs,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@overload
def add_request(
self,
request_id: str,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def add_request(
self,
request_id: str,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
"""Add a request to the engine's request pool.
@ -708,8 +743,7 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
@ -744,6 +778,10 @@ class LLMEngine:
>>> # continue the request processing
>>> ...
"""
if inputs is not None:
prompt = inputs
assert prompt is not None and params is not None
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
@ -756,7 +794,7 @@ class LLMEngine:
arrival_time = time.time()
preprocessed_inputs = self.input_preprocessor.preprocess(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,

View File

@ -1,13 +1,14 @@
from dataclasses import dataclass
from enum import Enum
from typing import List, Mapping, Optional, Union
from typing import List, Mapping, Optional, Union, overload
from vllm import PoolingParams
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs
VLLM_RPC_SUCCESS_STR = "SUCCESS"
@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError):
@dataclass
class RPCProcessRequest:
inputs: PromptInputs
prompt: PromptType
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
@overload # DEPRECATED
def __init__(
self,
*,
inputs: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@overload
def __init__(
self,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def __init__(
self,
prompt: Optional[PromptType] = None,
params: Optional[Union[SamplingParams, PoolingParams]] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
if inputs is not None:
prompt = inputs
assert (prompt is not None and params is not None
and request_id is not None)
super().__init__()
self.prompt = prompt
self.params = params
self.request_id = request_id
self.lora_request = lora_request
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
@dataclass
class RPCError:

View File

@ -3,7 +3,7 @@ import copy
import pickle
from contextlib import contextmanager, suppress
from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
Union)
Union, overload)
import cloudpickle
import zmq
@ -25,13 +25,14 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCUProfileRequest)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptInputs
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
logger = init_logger(__name__)
@ -367,14 +368,45 @@ class MQLLMEngineClient:
def dead_error(self) -> BaseException:
return ENGINE_DEAD_ERROR(self._errored_with)
@overload # DEPRECATED
def generate(
self,
inputs: PromptInputs,
*,
inputs: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
...
@overload
def generate(
self,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncGenerator[RequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def generate(
self,
prompt: Optional[PromptType] = None,
sampling_params: Optional[SamplingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request.
@ -383,8 +415,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
@ -393,17 +424,51 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return self._process_request(inputs, sampling_params, request_id,
if inputs is not None:
prompt = inputs
assert (prompt is not None and sampling_params is not None
and request_id is not None)
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
@overload # DEPRECATED
def encode(
self,
inputs: PromptInputs,
*,
inputs: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@overload
def encode(
self,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
...
@deprecate_kwargs(
"inputs",
additional_message="Please use the 'prompt' parameter instead.",
)
def encode(
self,
prompt: Optional[PromptType] = None,
pooling_params: Optional[PoolingParams] = None,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[EmbeddingRequestOutput, None]:
"""Generate outputs for a request from an embedding model.
@ -412,8 +477,7 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
@ -424,12 +488,17 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return self._process_request(inputs, pooling_params, request_id,
if inputs is not None:
prompt = inputs
assert (prompt is not None and pooling_params is not None
and request_id is not None)
return self._process_request(prompt, pooling_params, request_id,
lora_request, trace_headers)
async def _process_request(
self,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -462,7 +531,7 @@ class MQLLMEngineClient:
request_bytes = pickle.dumps(
RPCProcessRequest(
inputs=inputs,
prompt=prompt,
params=params,
request_id=request_id,
lora_request=lora_request,

View File

@ -278,7 +278,7 @@ class MQLLMEngine:
try:
self.engine.add_request(
request_id=request_id,
inputs=request.inputs,
prompt=request.prompt,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,

View File

@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.inputs.data import PromptType
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -35,19 +35,19 @@ class EngineClient(Protocol):
def generate(
self,
inputs: PromptInputs,
prompt: PromptType,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Generates outputs for a request"""
"""Generate outputs for a request."""
...
def encode(
self,
inputs: PromptInputs,
prompt: PromptType,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,

View File

@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -293,8 +293,8 @@ class LLM:
@overload
def generate(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@ -304,14 +304,13 @@ class LLM:
...
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
additional_message="Please use the 'prompts' parameter instead.",
)
def generate(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@ -330,7 +329,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: A list of inputs to generate completions for.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
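
A short end-to-end sketch of the updated call, with the model name assumed for illustration; prompts are passed positionally as before:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    sampling_params=SamplingParams(temperature=0.0, max_tokens=16),
)
for output in outputs:
    print(output.outputs[0].text)
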
@ -358,12 +359,13 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
@ -378,7 +380,7 @@ class LLM:
sampling_params = SamplingParams()
self._validate_and_add_requests(
inputs=inputs,
prompts=parsed_prompts,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -648,8 +650,8 @@ class LLM:
@overload
def encode(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
prompts: Union[PromptType, Sequence[PromptType]],
/,
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@ -659,14 +661,13 @@ class LLM:
...
@deprecate_kwargs(
"prompts",
"prompt_token_ids",
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
additional_message="Please use the 'inputs' parameter instead.",
additional_message="Please use the 'prompts' parameter instead.",
)
def encode(
self,
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[Union[PromptType, Sequence[PromptType]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@ -682,9 +683,9 @@ class LLM:
into a single list and pass it to this method.
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompt.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
@ -707,19 +708,20 @@ class LLM:
)
if prompt_token_ids is not None:
inputs = self._convert_v1_inputs(
parsed_prompts = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
inputs=inputs,
prompts=parsed_prompts,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -763,9 +765,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
inputs: List[PromptInputs] = []
parsed_prompts: List[PromptType] = []
for i in range(num_requests):
item: PromptInputs
item: PromptType
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
@ -774,13 +776,13 @@ class LLM:
else:
raise AssertionError
inputs.append(item)
parsed_prompts.append(item)
return inputs
return parsed_prompts
def _validate_and_add_requests(
self,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
prompts: Union[PromptType, Sequence[PromptType]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
@ -788,11 +790,11 @@ class LLM:
guided_options: Optional[GuidedDecodingRequest] = None,
priority: Optional[List[int]] = None,
) -> None:
if isinstance(inputs, (str, dict)):
if isinstance(prompts, (str, dict)):
# Convert a single prompt to a list.
inputs = [inputs]
prompts = [prompts]
num_requests = len(inputs)
num_requests = len(prompts)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
@ -809,9 +811,9 @@ class LLM:
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
for i, request_inputs in enumerate(inputs):
for i, prompt in enumerate(prompts):
self._add_request(
request_inputs,
prompt,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
@ -821,7 +823,7 @@ class LLM:
def _add_request(
self,
inputs: PromptInputs,
prompt: PromptType,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -830,7 +832,7 @@ class LLM:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
inputs,
prompt,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,

View File

@ -1,5 +1,5 @@
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
@ -16,8 +16,8 @@ See also:
__all__ = [
"TextPrompt",
"TokensPrompt",
"PromptInputs",
"SingletonPromptInputs",
"PromptType",
"SingletonPrompt",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",
@ -28,3 +28,17 @@ __all__ = [
"InputContext",
"InputRegistry",
]
def __getattr__(name: str):
if name == "PromptInput":
import warnings
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return PromptType
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
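
A quick check of this shim, assuming this commit is installed; accessing the old name still resolves but warns:

import warnings

import vllm.inputs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # The old attribute goes through __getattr__ and aliases the new type.
    assert vllm.inputs.PromptInput is vllm.inputs.PromptType
assert any(issubclass(w.category, DeprecationWarning) for w in caught)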

View File

@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
"""
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single LLM input:
@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPromptInputs` may be employed
A prompt of type :class:`SingletonPrompt` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
@ -55,33 +55,32 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co = TypeVar("_T1_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
bound=SingletonPrompt,
default=SingletonPrompt,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
bound=SingletonPrompt,
default=SingletonPrompt,
covariant=True)
# TODO: Make fields ReadOnly once mypy supports it
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
"""Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a
decoder prompt.
"""
Represents an encoder/decoder model input prompt,
comprising an explicit encoder prompt and a decoder prompt.
The encoder and decoder prompts, respectively,
may formatted according to any of the
:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
The encoder and decoder prompts, respectively, may be formatted
according to any of the :class:`SingletonPrompt` schemas,
and are not required to have the same schema.
Only the encoder prompt may have multi-modal data.
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
and that the :code:`encoder_prompt` and :code:`decoder_prompt`
fields of this data structure themselves must be
:class:`SingletonPromptInputs` instances.
:class:`SingletonPrompt` instances.
"""
encoder_prompt: _T1_co
@ -89,7 +88,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt: Optional[_T2_co]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
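
For concreteness, a small sketch of the explicit encoder/decoder form under these definitions; the prompt text and token IDs are arbitrary:

from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptType, TextPrompt,
                         TokensPrompt)

# The encoder and decoder prompts may use different singleton schemas.
enc_dec_prompt: PromptType = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt="An encoder prompt"),
    decoder_prompt=TokensPrompt(prompt_token_ids=[2, 0]),
)
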
@ -146,12 +145,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
def build_explicit_enc_dec_prompt(
@ -182,3 +177,17 @@ def to_enc_dec_tuple_list(
return [(enc_dec_prompt["encoder_prompt"],
enc_dec_prompt["decoder_prompt"])
for enc_dec_prompt in enc_dec_prompts]
def __getattr__(name: str):
if name == "PromptInput":
import warnings
msg = ("PromptInput has been renamed to PromptType. "
"The original name will be removed in an upcoming version.")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
return PromptType
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -5,7 +5,7 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
def parse_singleton_prompt(
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
if isinstance(inputs, str):
return ParsedStrPrompt(type="str", content=inputs)
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
if "prompt_token_ids" in prompt:
return ParsedTokensPrompt(type="tokens",
content=inputs) # type: ignore
elif "prompt" in inputs:
return ParsedTextPrompt(type="text", content=inputs)
content=prompt) # type: ignore
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_explicit_encoder_decoder_prompt(
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_valid_encoder_decoder_llm_inputs(

View File

@ -10,8 +10,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.utils import print_warning_once
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
SingletonPrompt)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
@ -209,7 +209,7 @@ class InputPreprocessor:
def _extract_prompt_components(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
@ -219,7 +219,7 @@ class InputPreprocessor:
Arguments:
* request_id
* inputs: single encoder or decoder input prompt
* prompt: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
@ -229,24 +229,24 @@ class InputPreprocessor:
* multi_modal_data
'''
parsed = parse_singleton_prompt(inputs)
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_text = parsed["content"]
prompt_token_ids = self._tokenize_prompt(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt = None
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_text = parsed["content"]["prompt"]
prompt_token_ids = self._tokenize_prompt(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
@ -254,33 +254,33 @@ class InputPreprocessor:
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
return prompt_text, prompt_token_ids, multi_modal_data
async def _extract_prompt_components_async(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(inputs)
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "str":
prompt = parsed["content"]
prompt_text = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt = None
prompt_text = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt = parsed["content"]["prompt"]
prompt_text = parsed["content"]["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt,
prompt_text,
request_id=request_id,
lora_request=lora_request,
)
@ -288,7 +288,7 @@ class InputPreprocessor:
else:
assert_never(parsed)
return prompt, prompt_token_ids, multi_modal_data
return prompt_text, prompt_token_ids, multi_modal_data
def _build_enc_dec_llm_inputs(
self,
@ -321,7 +321,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
@ -349,7 +349,7 @@ class InputPreprocessor:
Arguments:
* inputs: an input prompt
* prompt: an input prompt
* request_id
Returns:
@ -360,13 +360,13 @@ class InputPreprocessor:
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
encoder_comps = self._extract_prompt_components(
inputs["encoder_prompt"],
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
if (decoder_input := prompt["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
@ -375,7 +375,7 @@ class InputPreprocessor:
)
else:
encoder_comps = self._extract_prompt_components(
inputs,
prompt,
request_id=request_id,
)
@ -385,20 +385,20 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
encoder_task = self._extract_prompt_components_async(
inputs["encoder_prompt"],
prompt["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := inputs["decoder_prompt"]) is None:
if (decoder_input := prompt["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
@ -411,7 +411,7 @@ class InputPreprocessor:
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
inputs,
prompt,
request_id=request_id,
)
@ -435,7 +435,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -446,7 +446,7 @@ class InputPreprocessor:
Arguments:
* inputs: input prompt
* prompt: input prompt
* request_id
* lora_request
* prompt_adapter_request
@ -457,7 +457,7 @@ class InputPreprocessor:
'''
prompt_comps = self._extract_prompt_components(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
)
@ -469,14 +469,14 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async(
self,
inputs: SingletonPromptInputs,
prompt: SingletonPrompt,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
)
@ -488,7 +488,7 @@ class InputPreprocessor:
def preprocess(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -498,17 +498,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
inputs,
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return self._process_decoder_only_prompt(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -516,7 +516,7 @@ class InputPreprocessor:
async def preprocess_async(
self,
inputs: PromptInputs,
prompt: PromptType,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -526,17 +526,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
inputs,
prompt,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(inputs):
if is_explicit_encoder_decoder_prompt(prompt):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
inputs,
prompt,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,