Revert "[Core] Rename PromptInputs to PromptType, and inputs to prompt" (#8750)

Simon Mo 2024-09-23 22:45:20 -07:00 committed by GitHub
parent 0250dd68c5
commit 3185fb0cca
18 changed files with 162 additions and 157 deletions
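For orientation, a minimal sketch of the public call-site shape this revert restores, with the prompt(s) passed positionally to LLM.generate(); the model name and prompt text are illustrative assumptions, not part of the commit.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)

# Positional prompt list, matching the benchmark and entrypoint diffs below.
outputs = llm.generate(["Hello my name is"], sampling_params=sampling_params)
for output in outputs:
    print(output.outputs[0].text)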


@ -11,7 +11,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
from vllm.inputs import PromptType
from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_prompts: List[PromptType] = [{
dummy_inputs: List[PromptInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(dummy_prompts,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
llm.generate(dummy_prompts,
llm.generate(dummy_inputs,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()


@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.
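A short sketch of attaching multi-modal data to these prompt types via the ``multi_modal_data`` field; the image path and token IDs are illustrative assumptions.

from PIL import Image

from vllm.inputs import TextPrompt, TokensPrompt

image = Image.open("example.jpg")

# Text prompt with an attached image via the multi_modal_data field.
text_prompt = TextPrompt(
    prompt="USER: <image>\nDescribe the picture.\nASSISTANT:",
    multi_modal_data={"image": image},
)

# Token prompt carrying the same attachment.
tokens_prompt = TokensPrompt(
    prompt_token_ids=[1, 3148, 1001, 29901],
    multi_modal_data={"image": image},
)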


@ -1,7 +1,7 @@
LLM Inputs
==========
.. autodata:: vllm.inputs.PromptType
.. autodata:: vllm.inputs.PromptInputs
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:


@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
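An end-to-end sketch of the two fields described above; the model name, prompt format, and image are illustrative assumptions.

from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")
image = Image.open("example.jpg")

outputs = llm.generate({
    # "prompt" follows the format documented on HuggingFace for the model.
    "prompt": "USER: <image>\nWhat is shown in this image?\nASSISTANT:",
    # "multi_modal_data" follows the MultiModalDataDict schema.
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)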


@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@ -165,7 +165,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
prompt="Hello my name is",
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=2000),
request_id=uuid.uuid4()):
pass
@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
pass
# This request should be okay.
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass


@ -20,7 +20,7 @@ async def generate(
count = 0
async for out in client.generate(
request_id=request_id,
prompt="Hello my name is Robert and",
inputs="Hello my name is Robert and",
sampling_params=SamplingParams(max_tokens=num_tokens,
temperature=0)):


@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@ -19,7 +19,7 @@ __all__ = [
"__version_tuple__",
"LLM",
"ModelRegistry",
"PromptType",
"PromptInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",


@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@ -405,7 +405,7 @@ class _AsyncLLMEngine(LLMEngine):
async def add_request_async(
self,
request_id: str,
prompt: PromptType,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@ -420,7 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
arrival_time = time.time()
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
prompt,
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -777,7 +777,7 @@ class AsyncLLMEngine:
async def add_request(
self,
request_id: str,
prompt: PromptType,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@ -797,7 +797,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
prompt=prompt,
inputs=inputs,
params=params,
arrival_time=arrival_time or time.time(),
lora_request=lora_request,
@ -808,7 +808,7 @@ class AsyncLLMEngine:
async def generate(
self,
prompt: PromptType,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -822,7 +822,8 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
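A hedged async sketch of calling this reverted signature with the inputs= keyword; engine construction details and the model name are assumptions.

import asyncio
import uuid

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # Stream outputs for a single request identified by request_id.
    async for output in engine.generate(
            inputs="Hello my name is",
            sampling_params=SamplingParams(max_tokens=16),
            request_id=str(uuid.uuid4())):
        if output.finished:
            print(output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())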
@ -880,7 +881,7 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
prompt,
inputs,
sampling_params,
lora_request=lora_request,
trace_headers=trace_headers,
@ -890,7 +891,7 @@ class AsyncLLMEngine:
async def encode(
self,
prompt: PromptType,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -903,7 +904,8 @@ class AsyncLLMEngine:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
@ -957,7 +959,7 @@ class AsyncLLMEngine:
"""
async for output in await self.add_request(
request_id,
prompt,
inputs,
pooling_params,
lora_request=lora_request,
trace_headers=trace_headers,


@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptType)
InputRegistry, LLMInputs, PromptInputs)
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -689,7 +689,7 @@ class LLMEngine:
def add_request(
self,
request_id: str,
prompt: PromptType,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
@ -704,7 +704,8 @@ class LLMEngine:
Args:
request_id: The unique ID of the request.
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
params: Parameters for sampling or pooling.
:class:`~vllm.SamplingParams` for text generation.
@ -744,7 +745,7 @@ class LLMEngine:
arrival_time = time.time()
preprocessed_inputs = self.input_preprocessor.preprocess(
prompt,
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,


@ -3,7 +3,7 @@ from enum import Enum
from typing import List, Mapping, Optional, Union
from vllm import PoolingParams
from vllm.inputs import PromptType
from vllm.inputs import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):
@dataclass
class RPCProcessRequest:
prompt: PromptType
inputs: PromptInputs
params: Union[SamplingParams, PoolingParams]
request_id: str
lora_request: Optional[LoRARequest] = None


@ -25,7 +25,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCStartupResponse)
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -375,7 +375,7 @@ class MQLLMEngineClient:
def generate(
self,
prompt: PromptType,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -389,7 +389,8 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
sampling_params: The sampling parameters of the request.
request_id: The unique id of the request.
@ -398,13 +399,13 @@ class MQLLMEngineClient:
prompt_adapter_request: Prompt Adapter request to use
for generation, if any.
"""
return self._process_request(prompt, sampling_params, request_id,
return self._process_request(inputs, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request)
def encode(
self,
prompt: PromptType,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -417,7 +418,8 @@ class MQLLMEngineClient:
from the LLMEngine to the caller.
Args:
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
inputs: The inputs to the LLM. See
:class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters of the request.
request_id: The unique id of the request.
@ -428,12 +430,12 @@ class MQLLMEngineClient:
The output `EmbeddingRequestOutput` objects from the LLMEngine
for the request.
"""
return self._process_request(prompt, pooling_params, request_id,
return self._process_request(inputs, pooling_params, request_id,
lora_request, trace_headers)
async def _process_request(
self,
prompt: PromptType,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
request_id: str,
lora_request: Optional[LoRARequest] = None,
@ -466,7 +468,7 @@ class MQLLMEngineClient:
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
inputs=inputs,
params=params,
request_id=request_id,
lora_request=lora_request,


@ -252,7 +252,7 @@ class MQLLMEngine:
try:
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
inputs=request.inputs,
params=request.params,
lora_request=request.lora_request,
trace_headers=request.trace_headers,


@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptType
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@ -35,19 +35,19 @@ class EngineClient(Protocol):
def generate(
self,
prompt: PromptType,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncGenerator[RequestOutput, None]:
"""Generate outputs for a request."""
"""Generates outputs for a request"""
...
def encode(
self,
prompt: PromptType,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,


@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -293,8 +293,8 @@ class LLM:
@overload
def generate(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@ -311,7 +311,7 @@ class LLM:
)
def generate(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@ -329,9 +329,7 @@ class LLM:
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
inputs: A list of inputs to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
When it is a single value, it is applied to every prompt.
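A small sketch of the single-vs-per-prompt sampling_params behavior described here; the prompts and parameter values are illustrative assumptions.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
prompts = ["Hello my name is", "The capital of France is"]

# One SamplingParams per prompt, paired positionally with the prompts.
per_prompt_params = [
    SamplingParams(temperature=0.8, max_tokens=32),
    SamplingParams(temperature=0.0, max_tokens=8),
]
outputs = llm.generate(prompts, sampling_params=per_prompt_params)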
@ -357,13 +355,12 @@ class LLM:
"models (XForCausalLM, XForConditionalGeneration).")
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if isinstance(guided_options_request, dict):
if len(guided_options_request) > 1:
@ -378,7 +375,7 @@ class LLM:
sampling_params = SamplingParams()
self._validate_and_add_requests(
prompts=parsed_prompts,
inputs=inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -533,9 +530,9 @@ class LLM:
conversation, mm_data = parse_chat_messages(messages, model_config,
tokenizer)
prompt_data: Union[str, List[int]]
prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt_data = apply_mistral_chat_template(
prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
chat_template=chat_template,
@ -543,7 +540,7 @@ class LLM:
tools=tools,
)
else:
prompt_data = apply_hf_chat_template(
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
@ -551,17 +548,17 @@ class LLM:
tools=tools,
)
prompt: PromptType
if is_list_of(prompt_data, int):
prompt = TokensPrompt(prompt_token_ids=prompt_data)
inputs: PromptInputs
if is_list_of(prompt, int):
inputs = TokensPrompt(prompt_token_ids=prompt)
else:
prompt = TextPrompt(prompt=prompt_data)
inputs = TextPrompt(prompt=prompt)
if mm_data is not None:
prompt["multi_modal_data"] = mm_data
inputs["multi_modal_data"] = mm_data
return self.generate(
prompt,
inputs,
sampling_params=sampling_params,
use_tqdm=use_tqdm,
lora_request=lora_request,
@ -631,8 +628,8 @@ class LLM:
@overload
def encode(
self,
prompts: Union[PromptType, Sequence[PromptType]],
/,
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@ -649,7 +646,7 @@ class LLM:
)
def encode(
self,
prompts: Union[Union[PromptType, Sequence[PromptType]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@ -665,9 +662,9 @@ class LLM:
into a single list and pass it to this method.
Args:
prompts: The prompts to the LLM. You may pass a sequence of prompts
for batch inference. See :class:`~vllm.inputs.PromptType`
for more details about the format of each prompts.
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
use_tqdm: Whether to use tqdm to display the progress bar.
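A brief sketch of LLM.encode() for an embedding model; the model name and inputs are illustrative assumptions.

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct")
outputs = llm.encode(["Hello my name is", "The capital of France is"])
for output in outputs:
    # Each EmbeddingRequestOutput carries one embedding vector.
    print(len(output.outputs.embedding))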
@ -690,20 +687,19 @@ class LLM:
)
if prompt_token_ids is not None:
parsed_prompts = self._convert_v1_inputs(
inputs = self._convert_v1_inputs(
prompts=cast(Optional[Union[str, List[str]]], prompts),
prompt_token_ids=prompt_token_ids,
)
else:
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if pooling_params is None:
# Use default pooling params.
pooling_params = PoolingParams()
self._validate_and_add_requests(
prompts=parsed_prompts,
inputs=inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -747,9 +743,9 @@ class LLM:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
parsed_prompts: List[PromptType] = []
inputs: List[PromptInputs] = []
for i in range(num_requests):
item: PromptType
item: PromptInputs
if prompts is not None:
item = TextPrompt(prompt=prompts[i])
@ -758,24 +754,24 @@ class LLM:
else:
raise AssertionError
parsed_prompts.append(item)
inputs.append(item)
return parsed_prompts
return inputs
def _validate_and_add_requests(
self,
prompts: Union[PromptType, Sequence[PromptType]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
prompt_adapter_request: Optional[PromptAdapterRequest],
guided_options: Optional[GuidedDecodingRequest] = None,
) -> None:
if isinstance(prompts, (str, dict)):
if isinstance(inputs, (str, dict)):
# Convert a single prompt to a list.
prompts = [prompts]
inputs = [inputs]
num_requests = len(prompts)
num_requests = len(inputs)
if isinstance(params, list) and len(params) != num_requests:
raise ValueError("The lengths of prompts and params "
"must be the same.")
@ -792,9 +788,9 @@ class LLM:
sp.output_kind = RequestOutputKind.FINAL_ONLY
# Add requests to the engine.
for i, prompt in enumerate(prompts):
for i, request_inputs in enumerate(inputs):
self._add_request(
prompt,
request_inputs,
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
@ -803,7 +799,7 @@ class LLM:
def _add_request(
self,
prompt: PromptType,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -811,7 +807,7 @@ class LLM:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(
request_id,
prompt,
inputs,
params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,


@ -1,5 +1,5 @@
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt, build_explicit_enc_dec_prompt,
to_enc_dec_tuple_list, zip_enc_dec_prompts)
from .registry import InputContext, InputRegistry
@ -16,8 +16,8 @@ See also:
__all__ = [
"TextPrompt",
"TokensPrompt",
"PromptType",
"SingletonPrompt",
"PromptInputs",
"SingletonPromptInputs",
"ExplicitEncoderDecoderPrompt",
"LLMInputs",
"EncoderDecoderLLMInputs",


@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
"""
Set of possible schemas for a single LLM input:
@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
the user desires to express both the encoder & decoder
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
A prompt of type :class:`SingletonPromptType` may be employed
A prompt of type :class:`SingletonPromptInputs` may be employed
as (1) input to a decoder-only model, (2) input to
the encoder of an encoder/decoder model, in the scenario
where the decoder-prompt is not specified explicitly, or
@ -55,12 +55,12 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
"""
_T1_co = TypeVar("_T1_co",
bound=SingletonPrompt,
default=SingletonPrompt,
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
_T2_co = TypeVar("_T2_co",
bound=SingletonPrompt,
default=SingletonPrompt,
bound=SingletonPromptInputs,
default=SingletonPromptInputs,
covariant=True)
@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
The encoder and decoder prompts, respectively,
may formatted according to any of the
:class:`SingletonPromptType` schemas, and are not
:class:`SingletonPromptInputs` schemas, and are not
required to have the same schema.
Only the encoder prompt may have multi-modal data.
@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
be used as an input to a decoder-only model,
and that the `encoder_prompt` and `decoder_prompt`
fields of this data structure themselves must be
:class:`SingletonPromptType` instances.
:class:`SingletonPromptInputs` instances.
"""
encoder_prompt: _T1_co
@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
decoder_prompt: Optional[_T2_co]
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
"""
Set of possible schemas for an LLM input, including
both decoder-only and encoder/decoder input types:
@ -140,8 +140,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
"""
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
_T1 = TypeVar("_T1",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
_T2 = TypeVar("_T2",
bound=SingletonPromptInputs,
default=SingletonPromptInputs)
def build_explicit_enc_dec_prompt(
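A sketch of the reverted union: SingletonPromptInputs admits a plain string, a TextPrompt, or a TokensPrompt, and PromptInputs additionally admits an explicit encoder/decoder pair. The concrete values are illustrative assumptions.

from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptInputs,
                         TextPrompt, TokensPrompt)

singleton_str: PromptInputs = "Hello my name is"
singleton_text: PromptInputs = TextPrompt(prompt="Hello my name is")
singleton_tokens: PromptInputs = TokensPrompt(prompt_token_ids=[15496, 616, 1438])

# Explicit encoder/decoder pair; the decoder prompt may be None.
enc_dec: PromptInputs = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt="An encoder prompt"),
    decoder_prompt=None,
)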


@ -5,7 +5,7 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
TokensPrompt)
@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
def parse_singleton_prompt(
prompt: SingletonPrompt,
inputs: SingletonPromptInputs,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
if "prompt_token_ids" in prompt:
if isinstance(inputs, str):
return ParsedStrPrompt(type="str", content=inputs)
elif isinstance(inputs, dict):
if "prompt_token_ids" in inputs:
return ParsedTokensPrompt(type="tokens",
content=prompt) # type: ignore
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)
content=inputs) # type: ignore
elif "prompt" in inputs:
return ParsedTextPrompt(type="text", content=inputs)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_explicit_encoder_decoder_prompt(
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(prompt, dict) and "encoder_prompt" in prompt
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
return isinstance(inputs, dict) and "encoder_prompt" in inputs
def is_valid_encoder_decoder_llm_inputs(
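An illustrative check of the parsing helpers touched above; vllm.inputs.parse is an internal module, so treat this as a sketch rather than a stable public API.

from vllm.inputs.parse import (is_explicit_encoder_decoder_prompt,
                               parse_singleton_prompt)

parsed = parse_singleton_prompt("Hello my name is")
assert parsed["type"] == "str" and parsed["content"] == "Hello my name is"

parsed = parse_singleton_prompt({"prompt_token_ids": [1, 2, 3]})
assert parsed["type"] == "tokens"

assert is_explicit_encoder_decoder_prompt(
    {"encoder_prompt": "enc", "decoder_prompt": "dec"})
assert not is_explicit_encoder_decoder_prompt("just a string")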


@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
SingletonPrompt)
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
SingletonPromptInputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
if TYPE_CHECKING:
@ -206,7 +206,7 @@ class InputPreprocessor:
def _extract_prompt_components(
self,
prompt: SingletonPrompt,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
@ -216,7 +216,7 @@ class InputPreprocessor:
Arguments:
* request_id
* prompt: single encoder or decoder input prompt
* inputs: single encoder or decoder input prompt
* lora_request: this is only valid for decoder prompts
Returns:
@ -226,24 +226,24 @@ class InputPreprocessor:
* multi_modal_data
'''
parsed = parse_singleton_prompt(prompt)
parsed = parse_singleton_prompt(inputs)
if parsed["type"] == "str":
prompt_text = parsed["content"]
prompt = parsed["content"]
prompt_token_ids = self._tokenize_prompt(
prompt_text,
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt_text = None
prompt = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt_text = parsed["content"]["prompt"]
prompt = parsed["content"]["prompt"]
prompt_token_ids = self._tokenize_prompt(
prompt_text,
prompt,
request_id=request_id,
lora_request=lora_request,
)
@ -251,33 +251,33 @@ class InputPreprocessor:
else:
assert_never(parsed)
return prompt_text, prompt_token_ids, multi_modal_data
return prompt, prompt_token_ids, multi_modal_data
async def _extract_prompt_components_async(
self,
prompt: SingletonPrompt,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
) -> PromptComponents:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt)
parsed = parse_singleton_prompt(inputs)
if parsed["type"] == "str":
prompt_text = parsed["content"]
prompt = parsed["content"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
prompt,
request_id=request_id,
lora_request=lora_request,
)
multi_modal_data = None
elif parsed["type"] == "tokens":
prompt_text = None
prompt = None
prompt_token_ids = parsed["content"]["prompt_token_ids"]
multi_modal_data = parsed["content"].get("multi_modal_data")
elif parsed["type"] == "text":
prompt_text = parsed["content"]["prompt"]
prompt = parsed["content"]["prompt"]
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
prompt,
request_id=request_id,
lora_request=lora_request,
)
@ -285,7 +285,7 @@ class InputPreprocessor:
else:
assert_never(parsed)
return prompt_text, prompt_token_ids, multi_modal_data
return prompt, prompt_token_ids, multi_modal_data
def _build_enc_dec_llm_inputs(
self,
@ -311,7 +311,7 @@ class InputPreprocessor:
def _process_encoder_decoder_prompt(
self,
prompt: PromptType,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
'''
@ -339,7 +339,7 @@ class InputPreprocessor:
Arguments:
* prompt: an input prompt
* inputs: an input prompt
* request_id
Returns:
@ -350,13 +350,13 @@ class InputPreprocessor:
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(prompt):
if is_explicit_encoder_decoder_prompt(inputs):
encoder_comps = self._extract_prompt_components(
prompt["encoder_prompt"],
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
if (decoder_input := inputs["decoder_prompt"]) is None:
decoder_comps = None, None, None
else:
decoder_comps = self._extract_prompt_components(
@ -365,7 +365,7 @@ class InputPreprocessor:
)
else:
encoder_comps = self._extract_prompt_components(
prompt,
inputs,
request_id=request_id,
)
@ -375,20 +375,20 @@ class InputPreprocessor:
async def _process_encoder_decoder_prompt_async(
self,
prompt: PromptType,
inputs: PromptInputs,
request_id: str,
) -> EncoderDecoderLLMInputs:
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
encoder_comps: PromptComponents
decoder_comps: DecoderPromptComponents
if is_explicit_encoder_decoder_prompt(prompt):
if is_explicit_encoder_decoder_prompt(inputs):
encoder_task = self._extract_prompt_components_async(
prompt["encoder_prompt"],
inputs["encoder_prompt"],
request_id=request_id,
)
if (decoder_input := prompt["decoder_prompt"]) is None:
if (decoder_input := inputs["decoder_prompt"]) is None:
encoder_comps = await encoder_task
decoder_comps = None, None, None
else:
@ -401,7 +401,7 @@ class InputPreprocessor:
encoder_task, decoder_task)
else:
encoder_comps = await self._extract_prompt_components_async(
prompt,
inputs,
request_id=request_id,
)
@ -425,7 +425,7 @@ class InputPreprocessor:
def _process_decoder_only_prompt(
self,
prompt: SingletonPrompt,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -436,7 +436,7 @@ class InputPreprocessor:
Arguments:
* prompt: input prompt
* inputs: input prompt
* request_id
* lora_request
* prompt_adapter_request
@ -447,7 +447,7 @@ class InputPreprocessor:
'''
prompt_comps = self._extract_prompt_components(
prompt,
inputs,
request_id=request_id,
lora_request=lora_request,
)
@ -459,14 +459,14 @@ class InputPreprocessor:
async def _process_decoder_only_prompt_async(
self,
prompt: SingletonPrompt,
inputs: SingletonPromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
"""Async version of :meth:`_process_decoder_only_prompt`."""
prompt_comps = await self._extract_prompt_components_async(
prompt,
inputs,
request_id=request_id,
lora_request=lora_request,
)
@ -478,7 +478,7 @@ class InputPreprocessor:
def preprocess(
self,
prompt: PromptType,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -488,17 +488,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return self._process_encoder_decoder_prompt(
prompt,
inputs,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt):
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return self._process_decoder_only_prompt(
prompt,
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
@ -506,7 +506,7 @@ class InputPreprocessor:
async def preprocess_async(
self,
prompt: PromptType,
inputs: PromptInputs,
request_id: str,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
@ -516,17 +516,17 @@ class InputPreprocessor:
# Encoder-decoder model requires special mapping of
# input prompts to encoder & decoder
return await self._process_encoder_decoder_prompt_async(
prompt,
inputs,
request_id=request_id,
)
if is_explicit_encoder_decoder_prompt(prompt):
if is_explicit_encoder_decoder_prompt(inputs):
raise ValueError("Cannot pass encoder-decoder prompt "
"to decoder-only models")
# Decoder-only operation
return await self._process_decoder_only_prompt_async(
prompt,
inputs,
request_id=request_id,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,