Revert "[Core] Rename PromptInputs
to PromptType
, and inputs
to prompt
" (#8750)
This commit is contained in:
parent
0250dd68c5
commit
3185fb0cca
@ -11,7 +11,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
|
from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptInputs
|
||||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||||
from vllm.utils import FlexibleArgumentParser
|
from vllm.utils import FlexibleArgumentParser
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
|
|||||||
dummy_prompt_token_ids = np.random.randint(10000,
|
dummy_prompt_token_ids = np.random.randint(10000,
|
||||||
size=(args.batch_size,
|
size=(args.batch_size,
|
||||||
args.input_len))
|
args.input_len))
|
||||||
dummy_prompts: List[PromptType] = [{
|
dummy_inputs: List[PromptInputs] = [{
|
||||||
"prompt_token_ids": batch
|
"prompt_token_ids": batch
|
||||||
} for batch in dummy_prompt_token_ids.tolist()]
|
} for batch in dummy_prompt_token_ids.tolist()]
|
||||||
|
|
||||||
@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
|
|||||||
],
|
],
|
||||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||||
str(profile_dir))) as p:
|
str(profile_dir))) as p:
|
||||||
llm.generate(dummy_prompts,
|
llm.generate(dummy_inputs,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=False)
|
use_tqdm=False)
|
||||||
print(p.key_averages())
|
print(p.key_averages())
|
||||||
else:
|
else:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
llm.generate(dummy_prompts,
|
llm.generate(dummy_inputs,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=False)
|
use_tqdm=False)
|
||||||
end_time = time.perf_counter()
|
end_time = time.perf_counter()
|
||||||
|
@ -8,7 +8,7 @@ Multi-Modality
|
|||||||
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
|
||||||
|
|
||||||
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
|
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
|
||||||
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.
|
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
|
||||||
|
|
||||||
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
|
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
|
||||||
by following :ref:`this guide <adding_multimodal_plugin>`.
|
by following :ref:`this guide <adding_multimodal_plugin>`.
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
LLM Inputs
|
LLM Inputs
|
||||||
==========
|
==========
|
||||||
|
|
||||||
.. autodata:: vllm.inputs.PromptType
|
.. autodata:: vllm.inputs.PromptInputs
|
||||||
|
|
||||||
.. autoclass:: vllm.inputs.TextPrompt
|
.. autoclass:: vllm.inputs.TextPrompt
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
|
|||||||
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
|
We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
|
||||||
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
|
the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.
|
||||||
|
|
||||||
To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:
|
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
|
||||||
|
|
||||||
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
|
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
|
||||||
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
|
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
|
||||||
|
@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
|
|||||||
|
|
||||||
# Throws an error in first forward pass.
|
# Throws an error in first forward pass.
|
||||||
with pytest.raises(RAISED_ERROR):
|
with pytest.raises(RAISED_ERROR):
|
||||||
async for _ in client.generate(prompt="Hello my name is",
|
async for _ in client.generate(inputs="Hello my name is",
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
request_id=uuid.uuid4()):
|
request_id=uuid.uuid4()):
|
||||||
pass
|
pass
|
||||||
@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
|
|||||||
|
|
||||||
# Engine is errored, should get ENGINE_DEAD_ERROR.
|
# Engine is errored, should get ENGINE_DEAD_ERROR.
|
||||||
with pytest.raises(MQEngineDeadError):
|
with pytest.raises(MQEngineDeadError):
|
||||||
async for _ in client.generate(prompt="Hello my name is",
|
async for _ in client.generate(inputs="Hello my name is",
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
request_id=uuid.uuid4()):
|
request_id=uuid.uuid4()):
|
||||||
pass
|
pass
|
||||||
@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
|
|||||||
|
|
||||||
# Generate call should throw ENGINE_DEAD_ERROR
|
# Generate call should throw ENGINE_DEAD_ERROR
|
||||||
with pytest.raises(MQEngineDeadError):
|
with pytest.raises(MQEngineDeadError):
|
||||||
async for _ in client.generate(prompt="Hello my name is",
|
async for _ in client.generate(inputs="Hello my name is",
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
request_id=uuid.uuid4()):
|
request_id=uuid.uuid4()):
|
||||||
pass
|
pass
|
||||||
@ -165,7 +165,7 @@ async def test_failed_abort(tmp_socket):
|
|||||||
# with reference to the original KeyError("foo")
|
# with reference to the original KeyError("foo")
|
||||||
with pytest.raises(MQEngineDeadError) as execinfo:
|
with pytest.raises(MQEngineDeadError) as execinfo:
|
||||||
async for _ in client.generate(
|
async for _ in client.generate(
|
||||||
prompt="Hello my name is",
|
inputs="Hello my name is",
|
||||||
sampling_params=SamplingParams(max_tokens=2000),
|
sampling_params=SamplingParams(max_tokens=2000),
|
||||||
request_id=uuid.uuid4()):
|
request_id=uuid.uuid4()):
|
||||||
pass
|
pass
|
||||||
@ -190,7 +190,7 @@ async def test_bad_request(tmp_socket):
|
|||||||
|
|
||||||
# Invalid request should fail, but not crash the server.
|
# Invalid request should fail, but not crash the server.
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
async for _ in client.generate(prompt="Hello my name is",
|
async for _ in client.generate(inputs="Hello my name is",
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
request_id="abcd-1",
|
request_id="abcd-1",
|
||||||
lora_request=LoRARequest(
|
lora_request=LoRARequest(
|
||||||
@ -199,7 +199,7 @@ async def test_bad_request(tmp_socket):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
# This request should be okay.
|
# This request should be okay.
|
||||||
async for _ in client.generate(prompt="Hello my name is",
|
async for _ in client.generate(inputs="Hello my name is",
|
||||||
sampling_params=SamplingParams(),
|
sampling_params=SamplingParams(),
|
||||||
request_id="abcd-2"):
|
request_id="abcd-2"):
|
||||||
pass
|
pass
|
||||||
|
@ -20,7 +20,7 @@ async def generate(
|
|||||||
count = 0
|
count = 0
|
||||||
async for out in client.generate(
|
async for out in client.generate(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
prompt="Hello my name is Robert and",
|
inputs="Hello my name is Robert and",
|
||||||
sampling_params=SamplingParams(max_tokens=num_tokens,
|
sampling_params=SamplingParams(max_tokens=num_tokens,
|
||||||
temperature=0)):
|
temperature=0)):
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
|||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
from vllm.entrypoints.llm import LLM
|
from vllm.entrypoints.llm import LLM
|
||||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
|
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
|
||||||
EmbeddingRequestOutput, RequestOutput)
|
EmbeddingRequestOutput, RequestOutput)
|
||||||
@ -19,7 +19,7 @@ __all__ = [
|
|||||||
"__version_tuple__",
|
"__version_tuple__",
|
||||||
"LLM",
|
"LLM",
|
||||||
"ModelRegistry",
|
"ModelRegistry",
|
||||||
"PromptType",
|
"PromptInputs",
|
||||||
"TextPrompt",
|
"TextPrompt",
|
||||||
"TokensPrompt",
|
"TokensPrompt",
|
||||||
"SamplingParams",
|
"SamplingParams",
|
||||||
|
@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
|
|||||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||||
from vllm.executor.gpu_executor import GPUExecutorAsync
|
from vllm.executor.gpu_executor import GPUExecutorAsync
|
||||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptInputs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
@ -405,7 +405,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
async def add_request_async(
|
async def add_request_async(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -420,7 +420,7 @@ class _AsyncLLMEngine(LLMEngine):
|
|||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
|
|
||||||
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
|
preprocessed_inputs = await self.input_preprocessor.preprocess_async(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
@ -777,7 +777,7 @@ class AsyncLLMEngine:
|
|||||||
async def add_request(
|
async def add_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -797,7 +797,7 @@ class AsyncLLMEngine:
|
|||||||
stream = self._request_tracker.add_request(
|
stream = self._request_tracker.add_request(
|
||||||
request_id,
|
request_id,
|
||||||
verbose=self.log_requests,
|
verbose=self.log_requests,
|
||||||
prompt=prompt,
|
inputs=inputs,
|
||||||
params=params,
|
params=params,
|
||||||
arrival_time=arrival_time or time.time(),
|
arrival_time=arrival_time or time.time(),
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
@ -808,7 +808,7 @@ class AsyncLLMEngine:
|
|||||||
|
|
||||||
async def generate(
|
async def generate(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -822,7 +822,8 @@ class AsyncLLMEngine:
|
|||||||
from the LLMEngine to the caller.
|
from the LLMEngine to the caller.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
inputs: The inputs to the LLM. See
|
||||||
|
:class:`~vllm.inputs.PromptInputs`
|
||||||
for more details about the format of each input.
|
for more details about the format of each input.
|
||||||
sampling_params: The sampling parameters of the request.
|
sampling_params: The sampling parameters of the request.
|
||||||
request_id: The unique id of the request.
|
request_id: The unique id of the request.
|
||||||
@ -880,7 +881,7 @@ class AsyncLLMEngine:
|
|||||||
"""
|
"""
|
||||||
async for output in await self.add_request(
|
async for output in await self.add_request(
|
||||||
request_id,
|
request_id,
|
||||||
prompt,
|
inputs,
|
||||||
sampling_params,
|
sampling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
@ -890,7 +891,7 @@ class AsyncLLMEngine:
|
|||||||
|
|
||||||
async def encode(
|
async def encode(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
pooling_params: PoolingParams,
|
pooling_params: PoolingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -903,7 +904,8 @@ class AsyncLLMEngine:
|
|||||||
from the LLMEngine to the caller.
|
from the LLMEngine to the caller.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
inputs: The inputs to the LLM. See
|
||||||
|
:class:`~vllm.inputs.PromptInputs`
|
||||||
for more details about the format of each input.
|
for more details about the format of each input.
|
||||||
pooling_params: The pooling parameters of the request.
|
pooling_params: The pooling parameters of the request.
|
||||||
request_id: The unique id of the request.
|
request_id: The unique id of the request.
|
||||||
@ -957,7 +959,7 @@ class AsyncLLMEngine:
|
|||||||
"""
|
"""
|
||||||
async for output in await self.add_request(
|
async for output in await self.add_request(
|
||||||
request_id,
|
request_id,
|
||||||
prompt,
|
inputs,
|
||||||
pooling_params,
|
pooling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
trace_headers=trace_headers,
|
trace_headers=trace_headers,
|
||||||
|
@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
|
|||||||
from vllm.executor.gpu_executor import GPUExecutor
|
from vllm.executor.gpu_executor import GPUExecutor
|
||||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||||
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
|
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
|
||||||
InputRegistry, LLMInputs, PromptType)
|
InputRegistry, LLMInputs, PromptInputs)
|
||||||
from vllm.inputs.preprocess import InputPreprocessor
|
from vllm.inputs.preprocess import InputPreprocessor
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -689,7 +689,7 @@ class LLMEngine:
|
|||||||
def add_request(
|
def add_request(
|
||||||
self,
|
self,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
arrival_time: Optional[float] = None,
|
arrival_time: Optional[float] = None,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -704,7 +704,8 @@ class LLMEngine:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
request_id: The unique ID of the request.
|
request_id: The unique ID of the request.
|
||||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
inputs: The inputs to the LLM. See
|
||||||
|
:class:`~vllm.inputs.PromptInputs`
|
||||||
for more details about the format of each input.
|
for more details about the format of each input.
|
||||||
params: Parameters for sampling or pooling.
|
params: Parameters for sampling or pooling.
|
||||||
:class:`~vllm.SamplingParams` for text generation.
|
:class:`~vllm.SamplingParams` for text generation.
|
||||||
@ -744,7 +745,7 @@ class LLMEngine:
|
|||||||
arrival_time = time.time()
|
arrival_time = time.time()
|
||||||
|
|
||||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
@ -3,7 +3,7 @@ from enum import Enum
|
|||||||
from typing import List, Mapping, Optional, Union
|
from typing import List, Mapping, Optional, Union
|
||||||
|
|
||||||
from vllm import PoolingParams
|
from vllm import PoolingParams
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptInputs
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
@ -23,7 +23,7 @@ class MQEngineDeadError(RuntimeError):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RPCProcessRequest:
|
class RPCProcessRequest:
|
||||||
prompt: PromptType
|
inputs: PromptInputs
|
||||||
params: Union[SamplingParams, PoolingParams]
|
params: Union[SamplingParams, PoolingParams]
|
||||||
request_id: str
|
request_id: str
|
||||||
lora_request: Optional[LoRARequest] = None
|
lora_request: Optional[LoRARequest] = None
|
||||||
|
@ -25,7 +25,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
|||||||
RPCStartupResponse)
|
RPCStartupResponse)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||||
from vllm.inputs import PromptType
|
from vllm.inputs import PromptInputs
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||||
@ -375,7 +375,7 @@ class MQLLMEngineClient:
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -389,7 +389,8 @@ class MQLLMEngineClient:
|
|||||||
from the LLMEngine to the caller.
|
from the LLMEngine to the caller.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
inputs: The inputs to the LLM. See
|
||||||
|
:class:`~vllm.inputs.PromptInputs`
|
||||||
for more details about the format of each input.
|
for more details about the format of each input.
|
||||||
sampling_params: The sampling parameters of the request.
|
sampling_params: The sampling parameters of the request.
|
||||||
request_id: The unique id of the request.
|
request_id: The unique id of the request.
|
||||||
@ -398,13 +399,13 @@ class MQLLMEngineClient:
|
|||||||
prompt_adapter_request: Prompt Adapter request to use
|
prompt_adapter_request: Prompt Adapter request to use
|
||||||
for generation, if any.
|
for generation, if any.
|
||||||
"""
|
"""
|
||||||
return self._process_request(prompt, sampling_params, request_id,
|
return self._process_request(inputs, sampling_params, request_id,
|
||||||
lora_request, trace_headers,
|
lora_request, trace_headers,
|
||||||
prompt_adapter_request)
|
prompt_adapter_request)
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
pooling_params: PoolingParams,
|
pooling_params: PoolingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -417,7 +418,8 @@ class MQLLMEngineClient:
|
|||||||
from the LLMEngine to the caller.
|
from the LLMEngine to the caller.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
|
inputs: The inputs to the LLM. See
|
||||||
|
:class:`~vllm.inputs.PromptInputs`
|
||||||
for more details about the format of each input.
|
for more details about the format of each input.
|
||||||
pooling_params: The pooling parameters of the request.
|
pooling_params: The pooling parameters of the request.
|
||||||
request_id: The unique id of the request.
|
request_id: The unique id of the request.
|
||||||
@ -428,12 +430,12 @@ class MQLLMEngineClient:
|
|||||||
The output `EmbeddingRequestOutput` objects from the LLMEngine
|
The output `EmbeddingRequestOutput` objects from the LLMEngine
|
||||||
for the request.
|
for the request.
|
||||||
"""
|
"""
|
||||||
return self._process_request(prompt, pooling_params, request_id,
|
return self._process_request(inputs, pooling_params, request_id,
|
||||||
lora_request, trace_headers)
|
lora_request, trace_headers)
|
||||||
|
|
||||||
async def _process_request(
|
async def _process_request(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
@ -466,7 +468,7 @@ class MQLLMEngineClient:
|
|||||||
|
|
||||||
request_bytes = pickle.dumps(
|
request_bytes = pickle.dumps(
|
||||||
RPCProcessRequest(
|
RPCProcessRequest(
|
||||||
prompt=prompt,
|
inputs=inputs,
|
||||||
params=params,
|
params=params,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
|
@ -252,7 +252,7 @@ class MQLLMEngine:
|
|||||||
try:
|
try:
|
||||||
self.engine.add_request(
|
self.engine.add_request(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
prompt=request.prompt,
|
inputs=request.inputs,
|
||||||
params=request.params,
|
params=request.params,
|
||||||
lora_request=request.lora_request,
|
lora_request=request.lora_request,
|
||||||
trace_headers=request.trace_headers,
|
trace_headers=request.trace_headers,
|
||||||
|
@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,
|
|||||||
|
|
||||||
from vllm.config import DecodingConfig, ModelConfig
|
from vllm.config import DecodingConfig, ModelConfig
|
||||||
from vllm.core.scheduler import SchedulerOutputs
|
from vllm.core.scheduler import SchedulerOutputs
|
||||||
from vllm.inputs.data import PromptType
|
from vllm.inputs.data import PromptInputs
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
|
||||||
@ -35,19 +35,19 @@ class EngineClient(Protocol):
|
|||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
trace_headers: Optional[Mapping[str, str]] = None,
|
trace_headers: Optional[Mapping[str, str]] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||||
) -> AsyncGenerator[RequestOutput, None]:
|
) -> AsyncGenerator[RequestOutput, None]:
|
||||||
"""Generate outputs for a request."""
|
"""Generates outputs for a request"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
pooling_params: PoolingParams,
|
pooling_params: PoolingParams,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
|
@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
|||||||
apply_hf_chat_template,
|
apply_hf_chat_template,
|
||||||
apply_mistral_chat_template,
|
apply_mistral_chat_template,
|
||||||
parse_chat_messages)
|
parse_chat_messages)
|
||||||
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
||||||
from vllm.inputs.parse import parse_and_batch_prompt
|
from vllm.inputs.parse import parse_and_batch_prompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -293,8 +293,8 @@ class LLM:
|
|||||||
@overload
|
@overload
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: Union[PromptType, Sequence[PromptType]],
|
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
||||||
/,
|
/, # We may enable `inputs` keyword after removing the old API
|
||||||
*,
|
*,
|
||||||
sampling_params: Optional[Union[SamplingParams,
|
sampling_params: Optional[Union[SamplingParams,
|
||||||
Sequence[SamplingParams]]] = None,
|
Sequence[SamplingParams]]] = None,
|
||||||
@ -311,7 +311,7 @@ class LLM:
|
|||||||
)
|
)
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: Union[Union[PromptType, Sequence[PromptType]],
|
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
|
||||||
Optional[Union[str, List[str]]]] = None,
|
Optional[Union[str, List[str]]]] = None,
|
||||||
sampling_params: Optional[Union[SamplingParams,
|
sampling_params: Optional[Union[SamplingParams,
|
||||||
Sequence[SamplingParams]]] = None,
|
Sequence[SamplingParams]]] = None,
|
||||||
@ -329,9 +329,7 @@ class LLM:
|
|||||||
into a single list and pass it to this method.
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
inputs: A list of inputs to generate completions for.
|
||||||
for batch inference. See :class:`~vllm.inputs.PromptType`
|
|
||||||
for more details about the format of each prompts.
|
|
||||||
sampling_params: The sampling parameters for text generation. If
|
sampling_params: The sampling parameters for text generation. If
|
||||||
None, we use the default sampling parameters.
|
None, we use the default sampling parameters.
|
||||||
When it is a single value, it is applied to every prompt.
|
When it is a single value, it is applied to every prompt.
|
||||||
@ -357,13 +355,12 @@ class LLM:
|
|||||||
"models (XForCausalLM, XForConditionalGeneration).")
|
"models (XForCausalLM, XForConditionalGeneration).")
|
||||||
|
|
||||||
if prompt_token_ids is not None:
|
if prompt_token_ids is not None:
|
||||||
parsed_prompts = self._convert_v1_inputs(
|
inputs = self._convert_v1_inputs(
|
||||||
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
|
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
|
||||||
prompts)
|
|
||||||
|
|
||||||
if isinstance(guided_options_request, dict):
|
if isinstance(guided_options_request, dict):
|
||||||
if len(guided_options_request) > 1:
|
if len(guided_options_request) > 1:
|
||||||
@ -378,7 +375,7 @@ class LLM:
|
|||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
|
|
||||||
self._validate_and_add_requests(
|
self._validate_and_add_requests(
|
||||||
prompts=parsed_prompts,
|
inputs=inputs,
|
||||||
params=sampling_params,
|
params=sampling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
@ -533,9 +530,9 @@ class LLM:
|
|||||||
conversation, mm_data = parse_chat_messages(messages, model_config,
|
conversation, mm_data = parse_chat_messages(messages, model_config,
|
||||||
tokenizer)
|
tokenizer)
|
||||||
|
|
||||||
prompt_data: Union[str, List[int]]
|
prompt: Union[str, List[int]]
|
||||||
if isinstance(tokenizer, MistralTokenizer):
|
if isinstance(tokenizer, MistralTokenizer):
|
||||||
prompt_data = apply_mistral_chat_template(
|
prompt = apply_mistral_chat_template(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
chat_template=chat_template,
|
chat_template=chat_template,
|
||||||
@ -543,7 +540,7 @@ class LLM:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prompt_data = apply_hf_chat_template(
|
prompt = apply_hf_chat_template(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
chat_template=chat_template,
|
chat_template=chat_template,
|
||||||
@ -551,17 +548,17 @@ class LLM:
|
|||||||
tools=tools,
|
tools=tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt: PromptType
|
inputs: PromptInputs
|
||||||
if is_list_of(prompt_data, int):
|
if is_list_of(prompt, int):
|
||||||
prompt = TokensPrompt(prompt_token_ids=prompt_data)
|
inputs = TokensPrompt(prompt_token_ids=prompt)
|
||||||
else:
|
else:
|
||||||
prompt = TextPrompt(prompt=prompt_data)
|
inputs = TextPrompt(prompt=prompt)
|
||||||
|
|
||||||
if mm_data is not None:
|
if mm_data is not None:
|
||||||
prompt["multi_modal_data"] = mm_data
|
inputs["multi_modal_data"] = mm_data
|
||||||
|
|
||||||
return self.generate(
|
return self.generate(
|
||||||
prompt,
|
inputs,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=use_tqdm,
|
use_tqdm=use_tqdm,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
@ -631,8 +628,8 @@ class LLM:
|
|||||||
@overload
|
@overload
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompts: Union[PromptType, Sequence[PromptType]],
|
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
||||||
/,
|
/, # We may enable `inputs` keyword after removing the old API
|
||||||
*,
|
*,
|
||||||
pooling_params: Optional[Union[PoolingParams,
|
pooling_params: Optional[Union[PoolingParams,
|
||||||
Sequence[PoolingParams]]] = None,
|
Sequence[PoolingParams]]] = None,
|
||||||
@ -649,7 +646,7 @@ class LLM:
|
|||||||
)
|
)
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompts: Union[Union[PromptType, Sequence[PromptType]],
|
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
|
||||||
Optional[Union[str, List[str]]]] = None,
|
Optional[Union[str, List[str]]]] = None,
|
||||||
pooling_params: Optional[Union[PoolingParams,
|
pooling_params: Optional[Union[PoolingParams,
|
||||||
Sequence[PoolingParams]]] = None,
|
Sequence[PoolingParams]]] = None,
|
||||||
@ -665,9 +662,9 @@ class LLM:
|
|||||||
into a single list and pass it to this method.
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
inputs: The inputs to the LLM. You may pass a sequence of inputs for
|
||||||
for batch inference. See :class:`~vllm.inputs.PromptType`
|
batch inference. See :class:`~vllm.inputs.PromptInputs`
|
||||||
for more details about the format of each prompts.
|
for more details about the format of each input.
|
||||||
pooling_params: The pooling parameters for pooling. If None, we
|
pooling_params: The pooling parameters for pooling. If None, we
|
||||||
use the default pooling parameters.
|
use the default pooling parameters.
|
||||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||||
@ -690,20 +687,19 @@ class LLM:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prompt_token_ids is not None:
|
if prompt_token_ids is not None:
|
||||||
parsed_prompts = self._convert_v1_inputs(
|
inputs = self._convert_v1_inputs(
|
||||||
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
|
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
|
||||||
prompts)
|
|
||||||
|
|
||||||
if pooling_params is None:
|
if pooling_params is None:
|
||||||
# Use default pooling params.
|
# Use default pooling params.
|
||||||
pooling_params = PoolingParams()
|
pooling_params = PoolingParams()
|
||||||
|
|
||||||
self._validate_and_add_requests(
|
self._validate_and_add_requests(
|
||||||
prompts=parsed_prompts,
|
inputs=inputs,
|
||||||
params=pooling_params,
|
params=pooling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
@ -747,9 +743,9 @@ class LLM:
|
|||||||
raise ValueError("Either prompts or prompt_token_ids must be "
|
raise ValueError("Either prompts or prompt_token_ids must be "
|
||||||
"provided.")
|
"provided.")
|
||||||
|
|
||||||
parsed_prompts: List[PromptType] = []
|
inputs: List[PromptInputs] = []
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
item: PromptType
|
item: PromptInputs
|
||||||
|
|
||||||
if prompts is not None:
|
if prompts is not None:
|
||||||
item = TextPrompt(prompt=prompts[i])
|
item = TextPrompt(prompt=prompts[i])
|
||||||
@ -758,24 +754,24 @@ class LLM:
|
|||||||
else:
|
else:
|
||||||
raise AssertionError
|
raise AssertionError
|
||||||
|
|
||||||
parsed_prompts.append(item)
|
inputs.append(item)
|
||||||
|
|
||||||
return parsed_prompts
|
return inputs
|
||||||
|
|
||||||
def _validate_and_add_requests(
|
def _validate_and_add_requests(
|
||||||
self,
|
self,
|
||||||
prompts: Union[PromptType, Sequence[PromptType]],
|
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
||||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||||
Sequence[PoolingParams]],
|
Sequence[PoolingParams]],
|
||||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||||
guided_options: Optional[GuidedDecodingRequest] = None,
|
guided_options: Optional[GuidedDecodingRequest] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(prompts, (str, dict)):
|
if isinstance(inputs, (str, dict)):
|
||||||
# Convert a single prompt to a list.
|
# Convert a single prompt to a list.
|
||||||
prompts = [prompts]
|
inputs = [inputs]
|
||||||
|
|
||||||
num_requests = len(prompts)
|
num_requests = len(inputs)
|
||||||
if isinstance(params, list) and len(params) != num_requests:
|
if isinstance(params, list) and len(params) != num_requests:
|
||||||
raise ValueError("The lengths of prompts and params "
|
raise ValueError("The lengths of prompts and params "
|
||||||
"must be the same.")
|
"must be the same.")
|
||||||
@ -792,9 +788,9 @@ class LLM:
|
|||||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||||
|
|
||||||
# Add requests to the engine.
|
# Add requests to the engine.
|
||||||
for i, prompt in enumerate(prompts):
|
for i, request_inputs in enumerate(inputs):
|
||||||
self._add_request(
|
self._add_request(
|
||||||
prompt,
|
request_inputs,
|
||||||
params[i] if isinstance(params, Sequence) else params,
|
params[i] if isinstance(params, Sequence) else params,
|
||||||
lora_request=lora_request[i] if isinstance(
|
lora_request=lora_request[i] if isinstance(
|
||||||
lora_request, Sequence) else lora_request,
|
lora_request, Sequence) else lora_request,
|
||||||
@ -803,7 +799,7 @@ class LLM:
|
|||||||
|
|
||||||
def _add_request(
|
def _add_request(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -811,7 +807,7 @@ class LLM:
|
|||||||
request_id = str(next(self.request_counter))
|
request_id = str(next(self.request_counter))
|
||||||
self.llm_engine.add_request(
|
self.llm_engine.add_request(
|
||||||
request_id,
|
request_id,
|
||||||
prompt,
|
inputs,
|
||||||
params,
|
params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
||||||
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
|
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
|
||||||
TokensPrompt, build_explicit_enc_dec_prompt,
|
TokensPrompt, build_explicit_enc_dec_prompt,
|
||||||
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
||||||
from .registry import InputContext, InputRegistry
|
from .registry import InputContext, InputRegistry
|
||||||
@ -16,8 +16,8 @@ See also:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"TextPrompt",
|
"TextPrompt",
|
||||||
"TokensPrompt",
|
"TokensPrompt",
|
||||||
"PromptType",
|
"PromptInputs",
|
||||||
"SingletonPrompt",
|
"SingletonPromptInputs",
|
||||||
"ExplicitEncoderDecoderPrompt",
|
"ExplicitEncoderDecoderPrompt",
|
||||||
"LLMInputs",
|
"LLMInputs",
|
||||||
"EncoderDecoderLLMInputs",
|
"EncoderDecoderLLMInputs",
|
||||||
|
@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
|
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
|
||||||
"""
|
"""
|
||||||
Set of possible schemas for a single LLM input:
|
Set of possible schemas for a single LLM input:
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
|
|||||||
the user desires to express both the encoder & decoder
|
the user desires to express both the encoder & decoder
|
||||||
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
||||||
|
|
||||||
A prompt of type :class:`SingletonPromptType` may be employed
|
A prompt of type :class:`SingletonPromptInputs` may be employed
|
||||||
as (1) input to a decoder-only model, (2) input to
|
as (1) input to a decoder-only model, (2) input to
|
||||||
the encoder of an encoder/decoder model, in the scenario
|
the encoder of an encoder/decoder model, in the scenario
|
||||||
where the decoder-prompt is not specified explicitly, or
|
where the decoder-prompt is not specified explicitly, or
|
||||||
@ -55,12 +55,12 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_T1_co = TypeVar("_T1_co",
|
_T1_co = TypeVar("_T1_co",
|
||||||
bound=SingletonPrompt,
|
bound=SingletonPromptInputs,
|
||||||
default=SingletonPrompt,
|
default=SingletonPromptInputs,
|
||||||
covariant=True)
|
covariant=True)
|
||||||
_T2_co = TypeVar("_T2_co",
|
_T2_co = TypeVar("_T2_co",
|
||||||
bound=SingletonPrompt,
|
bound=SingletonPromptInputs,
|
||||||
default=SingletonPrompt,
|
default=SingletonPromptInputs,
|
||||||
covariant=True)
|
covariant=True)
|
||||||
|
|
||||||
|
|
||||||
@ -72,7 +72,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
|||||||
|
|
||||||
The encoder and decoder prompts, respectively,
|
The encoder and decoder prompts, respectively,
|
||||||
may formatted according to any of the
|
may formatted according to any of the
|
||||||
:class:`SingletonPromptType` schemas, and are not
|
:class:`SingletonPromptInputs` schemas, and are not
|
||||||
required to have the same schema.
|
required to have the same schema.
|
||||||
|
|
||||||
Only the encoder prompt may have multi-modal data.
|
Only the encoder prompt may have multi-modal data.
|
||||||
@ -81,7 +81,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
|||||||
be used as an input to a decoder-only model,
|
be used as an input to a decoder-only model,
|
||||||
and that the `encoder_prompt` and `decoder_prompt`
|
and that the `encoder_prompt` and `decoder_prompt`
|
||||||
fields of this data structure themselves must be
|
fields of this data structure themselves must be
|
||||||
:class:`SingletonPromptType` instances.
|
:class:`SingletonPromptInputs` instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
encoder_prompt: _T1_co
|
encoder_prompt: _T1_co
|
||||||
@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
|||||||
decoder_prompt: Optional[_T2_co]
|
decoder_prompt: Optional[_T2_co]
|
||||||
|
|
||||||
|
|
||||||
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
|
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
|
||||||
"""
|
"""
|
||||||
Set of possible schemas for an LLM input, including
|
Set of possible schemas for an LLM input, including
|
||||||
both decoder-only and encoder/decoder input types:
|
both decoder-only and encoder/decoder input types:
|
||||||
@ -140,8 +140,12 @@ class EncoderDecoderLLMInputs(LLMInputs):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
|
_T1 = TypeVar("_T1",
|
||||||
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
|
bound=SingletonPromptInputs,
|
||||||
|
default=SingletonPromptInputs)
|
||||||
|
_T2 = TypeVar("_T2",
|
||||||
|
bound=SingletonPromptInputs,
|
||||||
|
default=SingletonPromptInputs)
|
||||||
|
|
||||||
|
|
||||||
def build_explicit_enc_dec_prompt(
|
def build_explicit_enc_dec_prompt(
|
||||||
|
@ -5,7 +5,7 @@ from typing_extensions import TypeIs
|
|||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
||||||
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
|
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
|
||||||
TokensPrompt)
|
TokensPrompt)
|
||||||
|
|
||||||
|
|
||||||
@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
def parse_singleton_prompt(
|
def parse_singleton_prompt(
|
||||||
prompt: SingletonPrompt,
|
inputs: SingletonPromptInputs,
|
||||||
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
|
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
|
||||||
if isinstance(prompt, str):
|
if isinstance(inputs, str):
|
||||||
return ParsedStrPrompt(type="str", content=prompt)
|
return ParsedStrPrompt(type="str", content=inputs)
|
||||||
elif isinstance(prompt, dict):
|
elif isinstance(inputs, dict):
|
||||||
if "prompt_token_ids" in prompt:
|
if "prompt_token_ids" in inputs:
|
||||||
return ParsedTokensPrompt(type="tokens",
|
return ParsedTokensPrompt(type="tokens",
|
||||||
content=prompt) # type: ignore
|
content=inputs) # type: ignore
|
||||||
elif "prompt" in prompt:
|
elif "prompt" in inputs:
|
||||||
return ParsedTextPrompt(type="text", content=prompt)
|
return ParsedTextPrompt(type="text", content=inputs)
|
||||||
|
|
||||||
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
|
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
|
||||||
|
|
||||||
|
|
||||||
def is_explicit_encoder_decoder_prompt(
|
def is_explicit_encoder_decoder_prompt(
|
||||||
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
return isinstance(inputs, dict) and "encoder_prompt" in inputs
|
||||||
|
|
||||||
|
|
||||||
def is_valid_encoder_decoder_llm_inputs(
|
def is_valid_encoder_decoder_llm_inputs(
|
||||||
|
@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||||
|
|
||||||
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
|
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
|
||||||
SingletonPrompt)
|
SingletonPromptInputs)
|
||||||
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -206,7 +206,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def _extract_prompt_components(
|
def _extract_prompt_components(
|
||||||
self,
|
self,
|
||||||
prompt: SingletonPrompt,
|
inputs: SingletonPromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
) -> PromptComponents:
|
) -> PromptComponents:
|
||||||
@ -216,7 +216,7 @@ class InputPreprocessor:
|
|||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
* request_id
|
* request_id
|
||||||
* prompt: single encoder or decoder input prompt
|
* inputs: single encoder or decoder input prompt
|
||||||
* lora_request: this is only valid for decoder prompts
|
* lora_request: this is only valid for decoder prompts
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -226,24 +226,24 @@ class InputPreprocessor:
|
|||||||
* multi_modal_data
|
* multi_modal_data
|
||||||
'''
|
'''
|
||||||
|
|
||||||
parsed = parse_singleton_prompt(prompt)
|
parsed = parse_singleton_prompt(inputs)
|
||||||
|
|
||||||
if parsed["type"] == "str":
|
if parsed["type"] == "str":
|
||||||
prompt_text = parsed["content"]
|
prompt = parsed["content"]
|
||||||
prompt_token_ids = self._tokenize_prompt(
|
prompt_token_ids = self._tokenize_prompt(
|
||||||
prompt_text,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
elif parsed["type"] == "tokens":
|
elif parsed["type"] == "tokens":
|
||||||
prompt_text = None
|
prompt = None
|
||||||
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
||||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||||
elif parsed["type"] == "text":
|
elif parsed["type"] == "text":
|
||||||
prompt_text = parsed["content"]["prompt"]
|
prompt = parsed["content"]["prompt"]
|
||||||
prompt_token_ids = self._tokenize_prompt(
|
prompt_token_ids = self._tokenize_prompt(
|
||||||
prompt_text,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
@ -251,33 +251,33 @@ class InputPreprocessor:
|
|||||||
else:
|
else:
|
||||||
assert_never(parsed)
|
assert_never(parsed)
|
||||||
|
|
||||||
return prompt_text, prompt_token_ids, multi_modal_data
|
return prompt, prompt_token_ids, multi_modal_data
|
||||||
|
|
||||||
async def _extract_prompt_components_async(
|
async def _extract_prompt_components_async(
|
||||||
self,
|
self,
|
||||||
prompt: SingletonPrompt,
|
inputs: SingletonPromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
) -> PromptComponents:
|
) -> PromptComponents:
|
||||||
"""Async version of :meth:`_extract_prompt_components`."""
|
"""Async version of :meth:`_extract_prompt_components`."""
|
||||||
parsed = parse_singleton_prompt(prompt)
|
parsed = parse_singleton_prompt(inputs)
|
||||||
|
|
||||||
if parsed["type"] == "str":
|
if parsed["type"] == "str":
|
||||||
prompt_text = parsed["content"]
|
prompt = parsed["content"]
|
||||||
prompt_token_ids = await self._tokenize_prompt_async(
|
prompt_token_ids = await self._tokenize_prompt_async(
|
||||||
prompt_text,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
elif parsed["type"] == "tokens":
|
elif parsed["type"] == "tokens":
|
||||||
prompt_text = None
|
prompt = None
|
||||||
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
||||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||||
elif parsed["type"] == "text":
|
elif parsed["type"] == "text":
|
||||||
prompt_text = parsed["content"]["prompt"]
|
prompt = parsed["content"]["prompt"]
|
||||||
prompt_token_ids = await self._tokenize_prompt_async(
|
prompt_token_ids = await self._tokenize_prompt_async(
|
||||||
prompt_text,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
@ -285,7 +285,7 @@ class InputPreprocessor:
|
|||||||
else:
|
else:
|
||||||
assert_never(parsed)
|
assert_never(parsed)
|
||||||
|
|
||||||
return prompt_text, prompt_token_ids, multi_modal_data
|
return prompt, prompt_token_ids, multi_modal_data
|
||||||
|
|
||||||
def _build_enc_dec_llm_inputs(
|
def _build_enc_dec_llm_inputs(
|
||||||
self,
|
self,
|
||||||
@ -311,7 +311,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def _process_encoder_decoder_prompt(
|
def _process_encoder_decoder_prompt(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
) -> EncoderDecoderLLMInputs:
|
) -> EncoderDecoderLLMInputs:
|
||||||
'''
|
'''
|
||||||
@ -339,7 +339,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
* prompt: an input prompt
|
* inputs: an input prompt
|
||||||
* request_id
|
* request_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -350,13 +350,13 @@ class InputPreprocessor:
|
|||||||
encoder_comps: PromptComponents
|
encoder_comps: PromptComponents
|
||||||
decoder_comps: DecoderPromptComponents
|
decoder_comps: DecoderPromptComponents
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(prompt):
|
if is_explicit_encoder_decoder_prompt(inputs):
|
||||||
encoder_comps = self._extract_prompt_components(
|
encoder_comps = self._extract_prompt_components(
|
||||||
prompt["encoder_prompt"],
|
inputs["encoder_prompt"],
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
if (decoder_input := inputs["decoder_prompt"]) is None:
|
||||||
decoder_comps = None, None, None
|
decoder_comps = None, None, None
|
||||||
else:
|
else:
|
||||||
decoder_comps = self._extract_prompt_components(
|
decoder_comps = self._extract_prompt_components(
|
||||||
@ -365,7 +365,7 @@ class InputPreprocessor:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_comps = self._extract_prompt_components(
|
encoder_comps = self._extract_prompt_components(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -375,20 +375,20 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
async def _process_encoder_decoder_prompt_async(
|
async def _process_encoder_decoder_prompt_async(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
) -> EncoderDecoderLLMInputs:
|
) -> EncoderDecoderLLMInputs:
|
||||||
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
||||||
encoder_comps: PromptComponents
|
encoder_comps: PromptComponents
|
||||||
decoder_comps: DecoderPromptComponents
|
decoder_comps: DecoderPromptComponents
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(prompt):
|
if is_explicit_encoder_decoder_prompt(inputs):
|
||||||
encoder_task = self._extract_prompt_components_async(
|
encoder_task = self._extract_prompt_components_async(
|
||||||
prompt["encoder_prompt"],
|
inputs["encoder_prompt"],
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
if (decoder_input := inputs["decoder_prompt"]) is None:
|
||||||
encoder_comps = await encoder_task
|
encoder_comps = await encoder_task
|
||||||
decoder_comps = None, None, None
|
decoder_comps = None, None, None
|
||||||
else:
|
else:
|
||||||
@ -401,7 +401,7 @@ class InputPreprocessor:
|
|||||||
encoder_task, decoder_task)
|
encoder_task, decoder_task)
|
||||||
else:
|
else:
|
||||||
encoder_comps = await self._extract_prompt_components_async(
|
encoder_comps = await self._extract_prompt_components_async(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -425,7 +425,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def _process_decoder_only_prompt(
|
def _process_decoder_only_prompt(
|
||||||
self,
|
self,
|
||||||
prompt: SingletonPrompt,
|
inputs: SingletonPromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -436,7 +436,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
* prompt: input prompt
|
* inputs: input prompt
|
||||||
* request_id
|
* request_id
|
||||||
* lora_request
|
* lora_request
|
||||||
* prompt_adapter_request
|
* prompt_adapter_request
|
||||||
@ -447,7 +447,7 @@ class InputPreprocessor:
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
prompt_comps = self._extract_prompt_components(
|
prompt_comps = self._extract_prompt_components(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
@ -459,14 +459,14 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
async def _process_decoder_only_prompt_async(
|
async def _process_decoder_only_prompt_async(
|
||||||
self,
|
self,
|
||||||
prompt: SingletonPrompt,
|
inputs: SingletonPromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
) -> LLMInputs:
|
) -> LLMInputs:
|
||||||
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
||||||
prompt_comps = await self._extract_prompt_components_async(
|
prompt_comps = await self._extract_prompt_components_async(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
@ -478,7 +478,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -488,17 +488,17 @@ class InputPreprocessor:
|
|||||||
# Encoder-decoder model requires special mapping of
|
# Encoder-decoder model requires special mapping of
|
||||||
# input prompts to encoder & decoder
|
# input prompts to encoder & decoder
|
||||||
return self._process_encoder_decoder_prompt(
|
return self._process_encoder_decoder_prompt(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(prompt):
|
if is_explicit_encoder_decoder_prompt(inputs):
|
||||||
raise ValueError("Cannot pass encoder-decoder prompt "
|
raise ValueError("Cannot pass encoder-decoder prompt "
|
||||||
"to decoder-only models")
|
"to decoder-only models")
|
||||||
|
|
||||||
# Decoder-only operation
|
# Decoder-only operation
|
||||||
return self._process_decoder_only_prompt(
|
return self._process_decoder_only_prompt(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
@ -506,7 +506,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
async def preprocess_async(
|
async def preprocess_async(
|
||||||
self,
|
self,
|
||||||
prompt: PromptType,
|
inputs: PromptInputs,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -516,17 +516,17 @@ class InputPreprocessor:
|
|||||||
# Encoder-decoder model requires special mapping of
|
# Encoder-decoder model requires special mapping of
|
||||||
# input prompts to encoder & decoder
|
# input prompts to encoder & decoder
|
||||||
return await self._process_encoder_decoder_prompt_async(
|
return await self._process_encoder_decoder_prompt_async(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(prompt):
|
if is_explicit_encoder_decoder_prompt(inputs):
|
||||||
raise ValueError("Cannot pass encoder-decoder prompt "
|
raise ValueError("Cannot pass encoder-decoder prompt "
|
||||||
"to decoder-only models")
|
"to decoder-only models")
|
||||||
|
|
||||||
# Decoder-only operation
|
# Decoder-only operation
|
||||||
return await self._process_decoder_only_prompt_async(
|
return await self._process_decoder_only_prompt_async(
|
||||||
prompt,
|
inputs,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user