rename PromptInputs and inputs with backward compatibility (#8760)
commit 28e1299e60
parent 0c4d2ad5e6
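Summary: this commit renames the `PromptInputs` union type to `PromptType` and renames the `inputs` parameter of the engine/client APIs to `prompt` (and back to `prompts` on `LLM.generate`/`LLM.encode`), while keeping the old `inputs` keyword accepted behind a deprecation warning. A minimal before/after usage sketch follows; the model name and prompt text are placeholders, not taken from this diff, and running it requires a local model:

```python
from vllm import EngineArgs, LLMEngine, SamplingParams

# Placeholder model; any model supported by vLLM works here.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
params = SamplingParams(temperature=0.0, max_tokens=16)

# New spelling introduced by this commit: the prompt is passed positionally
# or via the renamed 'prompt' keyword.
engine.add_request("req-1", "Hello, my name is", params)
engine.add_request(request_id="req-2", prompt="Hello, my name is", params=params)

# Old spelling still works for now, but should emit a DeprecationWarning
# asking callers to switch to 'prompt'.
engine.add_request(request_id="req-3", inputs="Hello, my name is", params=params)
```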
@@ -11,7 +11,7 @@ from tqdm import tqdm
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser

@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
 dummy_prompt_token_ids = np.random.randint(10000,
 size=(args.batch_size,
 args.input_len))
-dummy_inputs: List[PromptInputs] = [{
+dummy_prompts: List[PromptType] = [{
 "prompt_token_ids": batch
 } for batch in dummy_prompt_token_ids.tolist()]

@@ -74,13 +74,13 @@ def main(args: argparse.Namespace):
 ],
 on_trace_ready=torch.profiler.tensorboard_trace_handler(
 str(profile_dir))) as p:
-llm.generate(dummy_inputs,
+llm.generate(dummy_prompts,
 sampling_params=sampling_params,
 use_tqdm=False)
 print(p.key_averages())
 else:
 start_time = time.perf_counter()
-llm.generate(dummy_inputs,
+llm.generate(dummy_prompts,
 sampling_params=sampling_params,
 use_tqdm=False)
 end_time = time.perf_counter()
@@ -8,7 +8,7 @@ Multi-Modality
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.

 Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`.

 Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
 by following :ref:`this guide <adding_multimodal_plugin>`.
@@ -1,7 +1,7 @@
 LLM Inputs
 ==========

-.. autodata:: vllm.inputs.PromptInputs
+.. autodata:: vllm.inputs.PromptType

 .. autoclass:: vllm.inputs.TextPrompt
 :show-inheritance:
@@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag
 We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow
 the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model.

-To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`:

 * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
@@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine):

 @pytest.mark.asyncio
 async def test_new_requests_event():
+params = SamplingParams()
+
 engine = MockAsyncLLMEngine()
 engine.start_background_loop()
 await asyncio.sleep(0.01)
 assert engine.engine.step_calls == 0

-await engine.add_request("1", "", None)
+await engine.add_request("1", "", params)
 await asyncio.sleep(0.01)
 assert engine.engine.add_request_calls == 1
 assert engine.engine.step_calls == 1

-await engine.add_request("2", "", None)
+await engine.add_request("2", "", params)
 engine.engine.generate("2")
 await asyncio.sleep(0)
 await asyncio.sleep(0)
@@ -111,7 +113,7 @@ async def test_new_requests_event():
 await asyncio.sleep(0.001)
 assert engine.engine.step_calls == old_step_calls

-await engine.add_request("3", "", None)
+await engine.add_request("3", "", params)
 await asyncio.sleep(0.01)
 assert engine.engine.add_request_calls == 3
 assert engine.engine.step_calls == old_step_calls + 1
@@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
 assert [o.outputs for o in o1] == [o.outputs for o in o2]


-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize('prompt', PROMPTS)
-def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
-pooling_params = PoolingParams()
-
-with pytest.warns(DeprecationWarning, match="'prompts'"):
-v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params)
-
-v2_output = llm.encode(prompt, pooling_params=pooling_params)
-assert_outputs_equal(v1_output, v2_output)
-
-v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params)
-assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
 def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
 assert_outputs_equal(v1_output, v2_output)


-@pytest.mark.skip_global_cleanup
-def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
-pooling_params = PoolingParams()
-
-with pytest.warns(DeprecationWarning, match="'prompts'"):
-v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params)
-
-v2_output = llm.encode(PROMPTS, pooling_params=pooling_params)
-assert_outputs_equal(v1_output, v2_output)
-
-v2_output = llm.encode(
-[{
-"prompt": p
-} for p in PROMPTS],
-pooling_params=pooling_params,
-)
-assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
 pooling_params = PoolingParams()
@@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
 assert [o.outputs for o in o1] == [o.outputs for o in o2]


-@pytest.mark.skip_global_cleanup
-@pytest.mark.parametrize('prompt', PROMPTS)
-def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt):
-sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-
-with pytest.warns(DeprecationWarning, match="'prompts'"):
-v1_output = llm.generate(prompts=prompt,
-sampling_params=sampling_params)
-
-v2_output = llm.generate(prompt, sampling_params=sampling_params)
-assert_outputs_equal(v1_output, v2_output)
-
-v2_output = llm.generate({"prompt": prompt},
-sampling_params=sampling_params)
-assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS)
 def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
@@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM,
 assert_outputs_equal(v1_output, v2_output)


-@pytest.mark.skip_global_cleanup
-def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM):
-sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
-
-with pytest.warns(DeprecationWarning, match="'prompts'"):
-v1_output = llm.generate(prompts=PROMPTS,
-sampling_params=sampling_params)
-
-v2_output = llm.generate(PROMPTS, sampling_params=sampling_params)
-assert_outputs_equal(v1_output, v2_output)
-
-v2_output = llm.generate(
-[{
-"prompt": p
-} for p in PROMPTS],
-sampling_params=sampling_params,
-)
-assert_outputs_equal(v1_output, v2_output)
-
-
 @pytest.mark.skip_global_cleanup
 def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM):
 sampling_params = SamplingParams(temperature=0.0, top_p=1.0)
@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):

 # Throws an error in first forward pass.
 with pytest.raises(RAISED_ERROR):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
 sampling_params=SamplingParams(),
 request_id=uuid.uuid4()):
 pass
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):

 # Engine is errored, should get ENGINE_DEAD_ERROR.
 with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
 sampling_params=SamplingParams(),
 request_id=uuid.uuid4()):
 pass
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):

 # Generate call should throw ENGINE_DEAD_ERROR
 with pytest.raises(MQEngineDeadError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
 sampling_params=SamplingParams(),
 request_id=uuid.uuid4()):
 pass
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
 # with reference to the original KeyError("foo")
 with pytest.raises(MQEngineDeadError) as execinfo:
 async for _ in client.generate(
-inputs="Hello my name is",
+prompt="Hello my name is",
 sampling_params=SamplingParams(max_tokens=10),
 request_id=uuid.uuid4()):
 pass
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):

 # Invalid request should fail, but not crash the server.
 with pytest.raises(ValueError):
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
 sampling_params=SamplingParams(),
 request_id="abcd-1",
 lora_request=LoRARequest(
@@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
 pass

 # This request should be okay.
-async for _ in client.generate(inputs="Hello my name is",
+async for _ in client.generate(prompt="Hello my name is",
 sampling_params=SamplingParams(),
 request_id="abcd-2"):
 pass
@@ -20,7 +20,7 @@ async def generate(
 count = 0
 async for out in client.generate(
 request_id=request_id,
-inputs="Hello my name is Robert and",
+prompt="Hello my name is Robert and",
 sampling_params=SamplingParams(max_tokens=num_tokens,
 temperature=0)):

@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
+from vllm.inputs import PromptType, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import (CompletionOutput, EmbeddingOutput,
 EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@ __all__ = [
 "__version_tuple__",
 "LLM",
 "ModelRegistry",
-"PromptInputs",
+"PromptType",
 "TextPrompt",
 "TokensPrompt",
 "SamplingParams",
@@ -2,8 +2,8 @@ import asyncio
 import time
 import weakref
 from functools import partial
-from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
-Mapping, Optional, Set, Tuple, Type, Union)
+from typing import (Any, AsyncGenerator, Callable, Coroutine, Dict, Iterable,
+List, Mapping, Optional, Set, Tuple, Type, Union, overload)
 from weakref import ReferenceType

 import vllm.envs as envs
@@ -17,7 +17,7 @@ from vllm.engine.metrics_types import StatLoggerBase
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -28,7 +28,7 @@ from vllm.sampling_params import SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import weak_bind
+from vllm.utils import deprecate_kwargs, weak_bind

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -402,17 +402,54 @@ class _AsyncLLMEngine(LLMEngine):
 """Stop the remote worker execution loop."""
 await self.model_executor.stop_remote_worker_execution_loop_async()

+@overload # DEPRECATED
 async def add_request_async(
 self,
 request_id: str,
-inputs: PromptInputs,
+*,
+inputs: PromptType,
 params: Union[SamplingParams, PoolingParams],
 arrival_time: Optional[float] = None,
 lora_request: Optional[LoRARequest] = None,
 trace_headers: Optional[Mapping[str, str]] = None,
 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> None:
+...
+
+@overload
+async def add_request_async(
+self,
+request_id: str,
+prompt: PromptType,
+params: Union[SamplingParams, PoolingParams],
+arrival_time: Optional[float] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> None:
+...
+
+@deprecate_kwargs(
+"inputs",
+additional_message="Please use the 'prompt' parameter instead.",
+)
+async def add_request_async(
+self,
+request_id: str,
+prompt: Optional[PromptType] = None,
+params: Optional[Union[SamplingParams, PoolingParams]] = None,
+arrival_time: Optional[float] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+*,
+inputs: Optional[PromptType] = None, # DEPRECATED
 ) -> None:
 """Async version of :meth:`add_request`."""
+if inputs is not None:
+prompt = inputs
+assert prompt is not None and params is not None
+
 if lora_request is not None and not self.lora_config:
 raise ValueError(f"Got lora_request {lora_request} but LoRA is "
 "not enabled!")
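The backward-compatibility pattern above (repeated for every renamed entry point in this commit) combines typed `@overload` stubs with the `deprecate_kwargs` decorator from `vllm.utils`, then re-maps the old keyword at the top of the real implementation (`if inputs is not None: prompt = inputs`). The diff does not show the helper itself; the following is only a rough sketch of how such a keyword-deprecation decorator can be written, not the actual vLLM implementation:

```python
import warnings
from functools import wraps
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def deprecate_kwargs(*deprecated: str, additional_message: str = "") -> Callable[[F], F]:
    """Warn whenever any of the named keyword arguments is passed explicitly."""

    def decorator(fn: F) -> F:

        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Detect which deprecated keywords the caller actually supplied.
            used = [name for name in deprecated if kwargs.get(name) is not None]
            if used:
                warnings.warn(
                    f"The keyword arguments {used} are deprecated and will be "
                    f"removed in a future release. {additional_message}",
                    DeprecationWarning,
                    stacklevel=2,
                )
            return fn(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator
```

The real helper in `vllm.utils` additionally accepts an `is_deprecated` callable (visible in the `LLM.generate` and `LLM.encode` hunks further down), which this sketch omits.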
@@ -420,7 +457,7 @@ class _AsyncLLMEngine(LLMEngine):
 arrival_time = time.time()

 preprocessed_inputs = await self.input_preprocessor.preprocess_async(
-inputs,
+prompt,
 request_id=request_id,
 lora_request=lora_request,
 prompt_adapter_request=prompt_adapter_request,
@@ -774,16 +811,55 @@ class AsyncLLMEngine:

 # This method does not need to be async, but kept that way
 # for backwards compatibility.
-async def add_request(
+@overload # DEPRECATED
+def add_request(
 self,
 request_id: str,
-inputs: PromptInputs,
+*,
+inputs: PromptType,
 params: Union[SamplingParams, PoolingParams],
 arrival_time: Optional[float] = None,
 lora_request: Optional[LoRARequest] = None,
 trace_headers: Optional[Mapping[str, str]] = None,
-prompt_adapter_request: Optional[PromptAdapterRequest] = None
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> Coroutine[None, None, AsyncGenerator[Union[
+RequestOutput, EmbeddingRequestOutput], None]]:
+...
+
+@overload
+def add_request(
+self,
+request_id: str,
+prompt: PromptType,
+params: Union[SamplingParams, PoolingParams],
+arrival_time: Optional[float] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> Coroutine[None, None, AsyncGenerator[Union[
+RequestOutput, EmbeddingRequestOutput], None]]:
+...
+
+@deprecate_kwargs(
+"inputs",
+additional_message="Please use the 'prompt' parameter instead.",
+)
+async def add_request(
+self,
+request_id: str,
+prompt: Optional[PromptType] = None,
+params: Optional[Union[SamplingParams, PoolingParams]] = None,
+arrival_time: Optional[float] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+*,
+inputs: Optional[PromptType] = None, # DEPRECATED
 ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]:
+if inputs is not None:
+prompt = inputs
+assert prompt is not None and params is not None
+
 if not self.is_running:
 if self.start_engine_loop:
 self.start_background_loop()
@@ -797,7 +873,7 @@ class AsyncLLMEngine:
 stream = self._request_tracker.add_request(
 request_id,
 verbose=self.log_requests,
-inputs=inputs,
+prompt=prompt,
 params=params,
 arrival_time=arrival_time or time.time(),
 lora_request=lora_request,
@@ -808,7 +884,7 @@ class AsyncLLMEngine:

 async def generate(
 self,
-inputs: PromptInputs,
+prompt: PromptType,
 sampling_params: SamplingParams,
 request_id: str,
 lora_request: Optional[LoRARequest] = None,
@@ -822,8 +898,7 @@ class AsyncLLMEngine:
 from the LLMEngine to the caller.

 Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
 for more details about the format of each input.
 sampling_params: The sampling parameters of the request.
 request_id: The unique id of the request.
@@ -881,7 +956,7 @@ class AsyncLLMEngine:
 """
 async for output in await self.add_request(
 request_id,
-inputs,
+prompt,
 sampling_params,
 lora_request=lora_request,
 trace_headers=trace_headers,
@@ -891,7 +966,7 @@ class AsyncLLMEngine:

 async def encode(
 self,
-inputs: PromptInputs,
+prompt: PromptType,
 pooling_params: PoolingParams,
 request_id: str,
 lora_request: Optional[LoRARequest] = None,
@@ -904,8 +979,7 @@ class AsyncLLMEngine:
 from the LLMEngine to the caller.

 Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
 for more details about the format of each input.
 pooling_params: The pooling parameters of the request.
 request_id: The unique id of the request.
@@ -959,7 +1033,7 @@ class AsyncLLMEngine:
 """
 async for output in await self.add_request(
 request_id,
-inputs,
+prompt,
 pooling_params,
 lora_request=lora_request,
 trace_headers=trace_headers,
@@ -6,7 +6,7 @@ from functools import partial
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
 Iterable, List, Mapping, NamedTuple, Optional)
 from typing import Sequence as GenericSequence
-from typing import Set, Type, Union
+from typing import Set, Type, Union, overload

 import torch
 from typing_extensions import TypeVar
@@ -29,7 +29,7 @@ from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
-InputRegistry, LLMInputs, PromptInputs)
+InputRegistry, LLMInputs, PromptType)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
 BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
 usage_message)
-from vllm.utils import Counter, Device, weak_bind
+from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -689,16 +689,51 @@ class LLMEngine:
 def stop_remote_worker_execution_loop(self) -> None:
 self.model_executor.stop_remote_worker_execution_loop()

+@overload # DEPRECATED
 def add_request(
 self,
 request_id: str,
-inputs: PromptInputs,
+*,
+inputs: PromptType,
 params: Union[SamplingParams, PoolingParams],
 arrival_time: Optional[float] = None,
 lora_request: Optional[LoRARequest] = None,
 trace_headers: Optional[Mapping[str, str]] = None,
 prompt_adapter_request: Optional[PromptAdapterRequest] = None,
 priority: int = 0,
+) -> None:
+...
+
+@overload
+def add_request(
+self,
+request_id: str,
+prompt: PromptType,
+params: Union[SamplingParams, PoolingParams],
+arrival_time: Optional[float] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+priority: int = 0,
+) -> None:
+...
+
+@deprecate_kwargs(
+"inputs",
+additional_message="Please use the 'prompt' parameter instead.",
+)
+def add_request(
+self,
+request_id: str,
+prompt: Optional[PromptType] = None,
+params: Optional[Union[SamplingParams, PoolingParams]] = None,
+arrival_time: Optional[float] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+priority: int = 0,
+*,
+inputs: Optional[PromptType] = None, # DEPRECATED
 ) -> None:
 """Add a request to the engine's request pool.

@@ -708,8 +743,7 @@ class LLMEngine:

 Args:
 request_id: The unique ID of the request.
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
 for more details about the format of each input.
 params: Parameters for sampling or pooling.
 :class:`~vllm.SamplingParams` for text generation.
@@ -744,6 +778,10 @@ class LLMEngine:
 >>> # continue the request processing
 >>> ...
 """
+if inputs is not None:
+prompt = inputs
+assert prompt is not None and params is not None
+
 if lora_request is not None and not self.lora_config:
 raise ValueError(f"Got lora_request {lora_request} but LoRA is "
 "not enabled!")
@@ -756,7 +794,7 @@ class LLMEngine:
 arrival_time = time.time()

 preprocessed_inputs = self.input_preprocessor.preprocess(
-inputs,
+prompt,
 request_id=request_id,
 lora_request=lora_request,
 prompt_adapter_request=prompt_adapter_request,
@@ -1,13 +1,14 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Mapping, Optional, Union
+from typing import List, Mapping, Optional, Union, overload

 from vllm import PoolingParams
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.outputs import RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
+from vllm.utils import deprecate_kwargs

 VLLM_RPC_SUCCESS_STR = "SUCCESS"

@@ -23,13 +24,67 @@ class MQEngineDeadError(RuntimeError):

 @dataclass
 class RPCProcessRequest:
-inputs: PromptInputs
+prompt: PromptType
 params: Union[SamplingParams, PoolingParams]
 request_id: str
 lora_request: Optional[LoRARequest] = None
 trace_headers: Optional[Mapping[str, str]] = None
 prompt_adapter_request: Optional[PromptAdapterRequest] = None
+
+@overload # DEPRECATED
+def __init__(
+self,
+*,
+inputs: PromptType,
+params: Union[SamplingParams, PoolingParams],
+request_id: str,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> None:
+...
+
+@overload
+def __init__(
+self,
+prompt: PromptType,
+params: Union[SamplingParams, PoolingParams],
+request_id: str,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> None:
+...
+
+@deprecate_kwargs(
+"inputs",
+additional_message="Please use the 'prompt' parameter instead.",
+)
+def __init__(
+self,
+prompt: Optional[PromptType] = None,
+params: Optional[Union[SamplingParams, PoolingParams]] = None,
+request_id: Optional[str] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+*,
+inputs: Optional[PromptType] = None, # DEPRECATED
+) -> None:
+if inputs is not None:
+prompt = inputs
+assert (prompt is not None and params is not None
+and request_id is not None)
+
+super().__init__()
+
+self.prompt = prompt
+self.params = params
+self.request_id = request_id
+self.lora_request = lora_request
+self.trace_headers = trace_headers
+self.prompt_adapter_request = prompt_adapter_request


 @dataclass
 class RPCError:
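Because `RPCProcessRequest` is a `@dataclass` whose generated `__init__` is overridden by the hand-written one above, a caller that still passes `inputs=` keeps getting a valid request object with the new `prompt` field populated. A short illustration of the expected behaviour (the prompt text and request ids below are made up for this example):

```python
from vllm import SamplingParams
from vllm.engine.multiprocessing import RPCProcessRequest

params = SamplingParams(max_tokens=8)

# New spelling.
req = RPCProcessRequest(prompt="Hello my name is", params=params, request_id="abcd-1")

# Deprecated spelling: still accepted, but expected to emit a DeprecationWarning
# pointing callers at the new 'prompt' parameter.
legacy = RPCProcessRequest(inputs="Hello my name is", params=params, request_id="abcd-2")
assert legacy.prompt == "Hello my name is"
```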
@@ -3,7 +3,7 @@ import copy
 import pickle
 from contextlib import contextmanager, suppress
 from typing import (Any, AsyncGenerator, Dict, Iterator, Mapping, Optional,
-Union)
+Union, overload)

 import cloudpickle
 import zmq
@@ -24,13 +24,14 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
 RPCStartupRequest, RPCStartupResponse)
 # yapf: enable
 from vllm.envs import VLLM_RPC_TIMEOUT
-from vllm.inputs import PromptInputs
+from vllm.inputs import PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
+from vllm.utils import deprecate_kwargs

 logger = init_logger(__name__)

@@ -366,14 +367,45 @@ class MQLLMEngineClient:
 def dead_error(self) -> BaseException:
 return ENGINE_DEAD_ERROR(self._errored_with)

+@overload # DEPRECATED
 def generate(
 self,
-inputs: PromptInputs,
+*,
+inputs: PromptType,
 sampling_params: SamplingParams,
 request_id: str,
 lora_request: Optional[LoRARequest] = None,
 trace_headers: Optional[Mapping[str, str]] = None,
-prompt_adapter_request: Optional[PromptAdapterRequest] = None
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> AsyncGenerator[RequestOutput, None]:
+...
+
+@overload
+def generate(
+self,
+prompt: PromptType,
+sampling_params: SamplingParams,
+request_id: str,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+) -> AsyncGenerator[RequestOutput, None]:
+...
+
+@deprecate_kwargs(
+"inputs",
+additional_message="Please use the 'prompt' parameter instead.",
+)
+def generate(
+self,
+prompt: Optional[PromptType] = None,
+sampling_params: Optional[SamplingParams] = None,
+request_id: Optional[str] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+*,
+inputs: Optional[PromptType] = None # DEPRECATED
 ) -> AsyncGenerator[RequestOutput, None]:
 """Generate outputs for a request.

@@ -382,8 +414,7 @@ class MQLLMEngineClient:
 from the LLMEngine to the caller.

 Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
 for more details about the format of each input.
 sampling_params: The sampling parameters of the request.
 request_id: The unique id of the request.
@@ -392,17 +423,51 @@ class MQLLMEngineClient:
 prompt_adapter_request: Prompt Adapter request to use
 for generation, if any.
 """
-return self._process_request(inputs, sampling_params, request_id,
+if inputs is not None:
+prompt = inputs
+assert (prompt is not None and sampling_params is not None
+and request_id is not None)
+
+return self._process_request(prompt, sampling_params, request_id,
 lora_request, trace_headers,
 prompt_adapter_request)

+@overload # DEPRECATED
 def encode(
 self,
-inputs: PromptInputs,
+*,
+inputs: PromptType,
 pooling_params: PoolingParams,
 request_id: str,
 lora_request: Optional[LoRARequest] = None,
 trace_headers: Optional[Mapping[str, str]] = None,
+) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+...
+
+@overload
+def encode(
+self,
+prompt: PromptType,
+pooling_params: PoolingParams,
+request_id: str,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+) -> AsyncGenerator[EmbeddingRequestOutput, None]:
+...
+
+@deprecate_kwargs(
+"inputs",
+additional_message="Please use the 'prompt' parameter instead.",
+)
+def encode(
+self,
+prompt: Optional[PromptType] = None,
+pooling_params: Optional[PoolingParams] = None,
+request_id: Optional[str] = None,
+lora_request: Optional[LoRARequest] = None,
+trace_headers: Optional[Mapping[str, str]] = None,
+*,
+inputs: Optional[PromptType] = None # DEPRECATED
 ) -> AsyncGenerator[EmbeddingRequestOutput, None]:
 """Generate outputs for a request from an embedding model.

@@ -411,8 +476,7 @@ class MQLLMEngineClient:
 from the LLMEngine to the caller.

 Args:
-inputs: The inputs to the LLM. See
-:class:`~vllm.inputs.PromptInputs`
+prompt: The prompt to the LLM. See :class:`~vllm.inputs.PromptType`
 for more details about the format of each input.
 pooling_params: The pooling parameters of the request.
 request_id: The unique id of the request.
@@ -423,12 +487,17 @@ class MQLLMEngineClient:
 The output `EmbeddingRequestOutput` objects from the LLMEngine
 for the request.
 """
-return self._process_request(inputs, pooling_params, request_id,
+if inputs is not None:
+prompt = inputs
+assert (prompt is not None and pooling_params is not None
+and request_id is not None)
+
+return self._process_request(prompt, pooling_params, request_id,
 lora_request, trace_headers)

 async def _process_request(
 self,
-inputs: PromptInputs,
+prompt: PromptType,
 params: Union[SamplingParams, PoolingParams],
 request_id: str,
 lora_request: Optional[LoRARequest] = None,
@@ -461,7 +530,7 @@ class MQLLMEngineClient:

 request_bytes = pickle.dumps(
 RPCProcessRequest(
-inputs=inputs,
+prompt=prompt,
 params=params,
 request_id=request_id,
 lora_request=lora_request,
@@ -271,7 +271,7 @@ class MQLLMEngine:
 try:
 self.engine.add_request(
 request_id=request_id,
-inputs=request.inputs,
+prompt=request.prompt,
 params=request.params,
 lora_request=request.lora_request,
 trace_headers=request.trace_headers,
@@ -3,7 +3,7 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol,

 from vllm.config import DecodingConfig, ModelConfig
 from vllm.core.scheduler import SchedulerOutputs
-from vllm.inputs.data import PromptInputs
+from vllm.inputs.data import PromptType
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
@@ -35,19 +35,19 @@ class EngineClient(Protocol):

 def generate(
 self,
-inputs: PromptInputs,
+prompt: PromptType,
 sampling_params: SamplingParams,
 request_id: str,
 lora_request: Optional[LoRARequest] = None,
 trace_headers: Optional[Mapping[str, str]] = None,
 prompt_adapter_request: Optional[PromptAdapterRequest] = None
 ) -> AsyncGenerator[RequestOutput, None]:
-"""Generates outputs for a request"""
+"""Generate outputs for a request."""
 ...

 def encode(
 self,
-inputs: PromptInputs,
+prompt: PromptType,
 pooling_params: PoolingParams,
 request_id: str,
 lora_request: Optional[LoRARequest] = None,
@ -12,7 +12,7 @@ from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
|||||||
apply_hf_chat_template,
|
apply_hf_chat_template,
|
||||||
apply_mistral_chat_template,
|
apply_mistral_chat_template,
|
||||||
parse_chat_messages)
|
parse_chat_messages)
|
||||||
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
|
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
|
||||||
from vllm.inputs.parse import parse_and_batch_prompt
|
from vllm.inputs.parse import parse_and_batch_prompt
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -293,8 +293,8 @@ class LLM:
|
|||||||
@overload
|
@overload
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
prompts: Union[PromptType, Sequence[PromptType]],
|
||||||
/, # We may enable `inputs` keyword after removing the old API
|
/,
|
||||||
*,
|
*,
|
||||||
sampling_params: Optional[Union[SamplingParams,
|
sampling_params: Optional[Union[SamplingParams,
|
||||||
Sequence[SamplingParams]]] = None,
|
Sequence[SamplingParams]]] = None,
|
||||||
@ -304,14 +304,13 @@ class LLM:
|
|||||||
...
|
...
|
||||||
|
|
||||||
@deprecate_kwargs(
|
@deprecate_kwargs(
|
||||||
"prompts",
|
|
||||||
"prompt_token_ids",
|
"prompt_token_ids",
|
||||||
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
|
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
|
||||||
additional_message="Please use the 'inputs' parameter instead.",
|
additional_message="Please use the 'prompts' parameter instead.",
|
||||||
)
|
)
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
|
prompts: Union[Union[PromptType, Sequence[PromptType]],
|
||||||
Optional[Union[str, List[str]]]] = None,
|
Optional[Union[str, List[str]]]] = None,
|
||||||
sampling_params: Optional[Union[SamplingParams,
|
sampling_params: Optional[Union[SamplingParams,
|
||||||
Sequence[SamplingParams]]] = None,
|
Sequence[SamplingParams]]] = None,
|
||||||
@ -330,7 +329,9 @@ class LLM:
|
|||||||
into a single list and pass it to this method.
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: A list of inputs to generate completions for.
|
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||||
|
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||||
|
for more details about the format of each prompts.
|
||||||
sampling_params: The sampling parameters for text generation. If
|
sampling_params: The sampling parameters for text generation. If
|
||||||
None, we use the default sampling parameters.
|
None, we use the default sampling parameters.
|
||||||
When it is a single value, it is applied to every prompt.
|
When it is a single value, it is applied to every prompt.
|
||||||
@ -358,12 +359,13 @@ class LLM:
|
|||||||
"models (XForCausalLM, XForConditionalGeneration).")
|
"models (XForCausalLM, XForConditionalGeneration).")
|
||||||
|
|
||||||
if prompt_token_ids is not None:
|
if prompt_token_ids is not None:
|
||||||
inputs = self._convert_v1_inputs(
|
parsed_prompts = self._convert_v1_inputs(
|
||||||
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
|
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
|
||||||
|
prompts)
|
||||||
|
|
||||||
if isinstance(guided_options_request, dict):
|
if isinstance(guided_options_request, dict):
|
||||||
if len(guided_options_request) > 1:
|
if len(guided_options_request) > 1:
|
||||||
@ -378,7 +380,7 @@ class LLM:
|
|||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
|
|
||||||
self._validate_and_add_requests(
|
self._validate_and_add_requests(
|
||||||
inputs=inputs,
|
prompts=parsed_prompts,
|
||||||
params=sampling_params,
|
params=sampling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
@ -648,8 +650,8 @@ class LLM:
|
|||||||
@overload
|
@overload
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
prompts: Union[PromptType, Sequence[PromptType]],
|
||||||
/, # We may enable `inputs` keyword after removing the old API
|
/,
|
||||||
*,
|
*,
|
||||||
pooling_params: Optional[Union[PoolingParams,
|
pooling_params: Optional[Union[PoolingParams,
|
||||||
Sequence[PoolingParams]]] = None,
|
Sequence[PoolingParams]]] = None,
|
||||||
@ -659,14 +661,13 @@ class LLM:
|
|||||||
...
|
...
|
||||||
|
|
||||||
@deprecate_kwargs(
|
@deprecate_kwargs(
|
||||||
"prompts",
|
|
||||||
"prompt_token_ids",
|
"prompt_token_ids",
|
||||||
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
|
is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
|
||||||
additional_message="Please use the 'inputs' parameter instead.",
|
additional_message="Please use the 'prompts' parameter instead.",
|
||||||
)
|
)
|
||||||
def encode(
|
def encode(
|
||||||
self,
|
self,
|
||||||
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
|
prompts: Union[Union[PromptType, Sequence[PromptType]],
|
||||||
Optional[Union[str, List[str]]]] = None,
|
Optional[Union[str, List[str]]]] = None,
|
||||||
pooling_params: Optional[Union[PoolingParams,
|
pooling_params: Optional[Union[PoolingParams,
|
||||||
Sequence[PoolingParams]]] = None,
|
Sequence[PoolingParams]]] = None,
|
||||||
@ -682,9 +683,9 @@ class LLM:
|
|||||||
into a single list and pass it to this method.
|
into a single list and pass it to this method.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inputs: The inputs to the LLM. You may pass a sequence of inputs for
|
prompts: The prompts to the LLM. You may pass a sequence of prompts
|
||||||
batch inference. See :class:`~vllm.inputs.PromptInputs`
|
for batch inference. See :class:`~vllm.inputs.PromptType`
|
||||||
for more details about the format of each input.
|
for more details about the format of each prompt.
|
||||||
pooling_params: The pooling parameters for pooling. If None, we
|
pooling_params: The pooling parameters for pooling. If None, we
|
||||||
use the default pooling parameters.
|
use the default pooling parameters.
|
||||||
use_tqdm: Whether to use tqdm to display the progress bar.
|
use_tqdm: Whether to use tqdm to display the progress bar.
|
||||||
@ -707,19 +708,20 @@ class LLM:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if prompt_token_ids is not None:
|
if prompt_token_ids is not None:
|
||||||
inputs = self._convert_v1_inputs(
|
parsed_prompts = self._convert_v1_inputs(
|
||||||
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
prompts=cast(Optional[Union[str, List[str]]], prompts),
|
||||||
prompt_token_ids=prompt_token_ids,
|
prompt_token_ids=prompt_token_ids,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
|
parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
|
||||||
|
prompts)
|
||||||
|
|
||||||
if pooling_params is None:
|
if pooling_params is None:
|
||||||
# Use default pooling params.
|
# Use default pooling params.
|
||||||
pooling_params = PoolingParams()
|
pooling_params = PoolingParams()
|
||||||
|
|
||||||
self._validate_and_add_requests(
|
self._validate_and_add_requests(
|
||||||
inputs=inputs,
|
prompts=parsed_prompts,
|
||||||
params=pooling_params,
|
params=pooling_params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
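The encode() path mirrors generate(); a hedged sketch of the new-style call, assuming an embedding-capable model (the model name and output field names come from this vLLM generation, not from the diff itself):

from vllm import LLM

# Illustrative embedding model; any model served in embedding mode would do.
llm = LLM(model="intfloat/e5-mistral-7b-instruct")

# The renamed parameter behaves like generate(): a single PromptType or a
# sequence of them.
outputs = llm.encode(["A sentence to embed", {"prompt_token_ids": [1, 2, 3]}])
for output in outputs:
    print(len(output.outputs.embedding))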
@ -763,9 +765,9 @@ class LLM:
|
|||||||
raise ValueError("Either prompts or prompt_token_ids must be "
|
raise ValueError("Either prompts or prompt_token_ids must be "
|
||||||
"provided.")
|
"provided.")
|
||||||
|
|
||||||
inputs: List[PromptInputs] = []
|
parsed_prompts: List[PromptType] = []
|
||||||
for i in range(num_requests):
|
for i in range(num_requests):
|
||||||
item: PromptInputs
|
item: PromptType
|
||||||
|
|
||||||
if prompts is not None:
|
if prompts is not None:
|
||||||
item = TextPrompt(prompt=prompts[i])
|
item = TextPrompt(prompt=prompts[i])
|
||||||
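The legacy conversion begun in this hunk (and finished in the next) is mechanical; a simplified, hypothetical stand-in for _convert_v1_inputs shows how the old arguments map onto PromptType values:

from typing import List, Optional, Union

from vllm.inputs import PromptType, TextPrompt, TokensPrompt

def convert_v1_inputs(
    prompts: Optional[Union[str, List[str]]],
    prompt_token_ids: Optional[List[List[int]]],
) -> List[PromptType]:
    # Simplified mirror: strings become TextPrompt dicts, token-id lists
    # become TokensPrompt dicts; prompts win if both are supplied.
    if isinstance(prompts, str):
        prompts = [prompts]
    if prompts is not None:
        return [TextPrompt(prompt=p) for p in prompts]
    if prompt_token_ids is not None:
        return [TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids]
    raise ValueError("Either prompts or prompt_token_ids must be provided.")

print(convert_v1_inputs("Hello", None))
print(convert_v1_inputs(None, [[1, 2, 3]]))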
@ -774,13 +776,13 @@ class LLM:
|
|||||||
else:
|
else:
|
||||||
raise AssertionError
|
raise AssertionError
|
||||||
|
|
||||||
inputs.append(item)
|
parsed_prompts.append(item)
|
||||||
|
|
||||||
return inputs
|
return parsed_prompts
|
||||||
|
|
||||||
def _validate_and_add_requests(
|
def _validate_and_add_requests(
|
||||||
self,
|
self,
|
||||||
inputs: Union[PromptInputs, Sequence[PromptInputs]],
|
prompts: Union[PromptType, Sequence[PromptType]],
|
||||||
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
|
||||||
Sequence[PoolingParams]],
|
Sequence[PoolingParams]],
|
||||||
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
|
||||||
@ -788,11 +790,11 @@ class LLM:
|
|||||||
guided_options: Optional[GuidedDecodingRequest] = None,
|
guided_options: Optional[GuidedDecodingRequest] = None,
|
||||||
priority: Optional[List[int]] = None,
|
priority: Optional[List[int]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
if isinstance(inputs, (str, dict)):
|
if isinstance(prompts, (str, dict)):
|
||||||
# Convert a single prompt to a list.
|
# Convert a single prompt to a list.
|
||||||
inputs = [inputs]
|
prompts = [prompts]
|
||||||
|
|
||||||
num_requests = len(inputs)
|
num_requests = len(prompts)
|
||||||
if isinstance(params, list) and len(params) != num_requests:
|
if isinstance(params, list) and len(params) != num_requests:
|
||||||
raise ValueError("The lengths of prompts and params "
|
raise ValueError("The lengths of prompts and params "
|
||||||
"must be the same.")
|
"must be the same.")
|
||||||
@ -809,9 +811,9 @@ class LLM:
|
|||||||
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
sp.output_kind = RequestOutputKind.FINAL_ONLY
|
||||||
|
|
||||||
# Add requests to the engine.
|
# Add requests to the engine.
|
||||||
for i, request_inputs in enumerate(inputs):
|
for i, prompt in enumerate(prompts):
|
||||||
self._add_request(
|
self._add_request(
|
||||||
request_inputs,
|
prompt,
|
||||||
params[i] if isinstance(params, Sequence) else params,
|
params[i] if isinstance(params, Sequence) else params,
|
||||||
lora_request=lora_request[i] if isinstance(
|
lora_request=lora_request[i] if isinstance(
|
||||||
lora_request, Sequence) else lora_request,
|
lora_request, Sequence) else lora_request,
|
||||||
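Because params may be a Sequence, each prompt can be paired with its own sampling configuration via the params[i] selection above; a short sketch (model name illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative

prompts = ["Write a haiku about GPUs.", "Summarize the theory of relativity."]
per_prompt_params = [
    SamplingParams(temperature=1.0, max_tokens=48),
    SamplingParams(temperature=0.2, max_tokens=96),
]
# The lengths must match, or the check above raises ValueError.
outputs = llm.generate(prompts, per_prompt_params)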
@ -821,7 +823,7 @@ class LLM:
|
|||||||
|
|
||||||
def _add_request(
|
def _add_request(
|
||||||
self,
|
self,
|
||||||
inputs: PromptInputs,
|
prompt: PromptType,
|
||||||
params: Union[SamplingParams, PoolingParams],
|
params: Union[SamplingParams, PoolingParams],
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -830,7 +832,7 @@ class LLM:
|
|||||||
request_id = str(next(self.request_counter))
|
request_id = str(next(self.request_counter))
|
||||||
self.llm_engine.add_request(
|
self.llm_engine.add_request(
|
||||||
request_id,
|
request_id,
|
||||||
inputs,
|
prompt,
|
||||||
params,
|
params,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
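At the layer below, the renamed value is forwarded positionally to llm_engine.add_request; a hedged sketch of driving the engine directly (engine arguments illustrative):

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("request-0",
                   {"prompt": "The capital of France is"},
                   SamplingParams(max_tokens=8))

while engine.has_unfinished_requests():
    for request_output in engine.step():
        if request_output.finished:
            print(request_output.outputs[0].text)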
|
@ -1,5 +1,5 @@
|
|||||||
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
||||||
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
|
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
|
||||||
TokensPrompt, build_explicit_enc_dec_prompt,
|
TokensPrompt, build_explicit_enc_dec_prompt,
|
||||||
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
to_enc_dec_tuple_list, zip_enc_dec_prompts)
|
||||||
from .registry import InputContext, InputRegistry
|
from .registry import InputContext, InputRegistry
|
||||||
@ -16,8 +16,8 @@ See also:
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"TextPrompt",
|
"TextPrompt",
|
||||||
"TokensPrompt",
|
"TokensPrompt",
|
||||||
"PromptInputs",
|
"PromptType",
|
||||||
"SingletonPromptInputs",
|
"SingletonPrompt",
|
||||||
"ExplicitEncoderDecoderPrompt",
|
"ExplicitEncoderDecoderPrompt",
|
||||||
"LLMInputs",
|
"LLMInputs",
|
||||||
"EncoderDecoderLLMInputs",
|
"EncoderDecoderLLMInputs",
|
||||||
@ -28,3 +28,17 @@ __all__ = [
|
|||||||
"InputContext",
|
"InputContext",
|
||||||
"InputRegistry",
|
"InputRegistry",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
if name == "PromptInput":
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
msg = ("PromptInput has been renamed to PromptType. "
|
||||||
|
"The original name will be removed in an upcoming version.")
|
||||||
|
|
||||||
|
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||||
|
|
||||||
|
return PromptType
|
||||||
|
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
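The module-level __getattr__ (PEP 562) keeps the old attribute access working while steering users to PromptType; a quick check of the behaviour, assuming a vLLM build that contains this change:

import warnings

import vllm.inputs

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy = vllm.inputs.PromptInput  # resolved through __getattr__ above
assert legacy is vllm.inputs.PromptType
assert any(issubclass(w.category, DeprecationWarning) for w in caught)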
|
@ -33,7 +33,7 @@ class TokensPrompt(TypedDict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt]
|
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
|
||||||
"""
|
"""
|
||||||
Set of possible schemas for a single LLM input:
|
Set of possible schemas for a single LLM input:
|
||||||
|
|
||||||
@ -46,7 +46,7 @@ which may be utilized for encoder/decoder models when
|
|||||||
the user desires to express both the encoder & decoder
|
the user desires to express both the encoder & decoder
|
||||||
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
prompts explicitly, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
||||||
|
|
||||||
A prompt of type :class:`SingletonPromptInputs` may be employed
|
A prompt of type :class:`SingletonPrompt` may be employed
|
||||||
as (1) input to a decoder-only model, (2) input to
|
as (1) input to a decoder-only model, (2) input to
|
||||||
the encoder of an encoder/decoder model, in the scenario
|
the encoder of an encoder/decoder model, in the scenario
|
||||||
where the decoder-prompt is not specified explicitly, or
|
where the decoder-prompt is not specified explicitly, or
|
||||||
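A hedged illustration of the three SingletonPrompt schemas accepted after the rename (the prompt text and token ids are placeholders):

from typing import List

from vllm.inputs import SingletonPrompt, TextPrompt, TokensPrompt

prompts: List[SingletonPrompt] = [
    "An astronaut riding a horse",                     # plain string
    TextPrompt(prompt="An astronaut riding a horse"),  # TextPrompt schema
    TokensPrompt(prompt_token_ids=[101, 2019, 102]),   # pre-tokenized input
]
# Any of these may be passed to LLM.generate individually or in a list.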
@ -55,33 +55,33 @@ more than one prompt, i.e. :class:`ExplicitEncoderDecoderPrompt`
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
_T1_co = TypeVar("_T1_co",
|
_T1_co = TypeVar("_T1_co",
|
||||||
bound=SingletonPromptInputs,
|
bound=SingletonPrompt,
|
||||||
default=SingletonPromptInputs,
|
default=SingletonPrompt,
|
||||||
covariant=True)
|
covariant=True)
|
||||||
_T2_co = TypeVar("_T2_co",
|
_T2_co = TypeVar("_T2_co",
|
||||||
bound=SingletonPromptInputs,
|
bound=SingletonPrompt,
|
||||||
default=SingletonPromptInputs,
|
default=SingletonPrompt,
|
||||||
covariant=True)
|
covariant=True)
|
||||||
|
|
||||||
|
|
||||||
# TODO: Make fields ReadOnly once mypy supports it
|
# TODO: Make fields ReadOnly once mypy supports it
|
||||||
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
||||||
"""Represents an encoder/decoder model input prompt,
|
"""
|
||||||
comprising an explicit encoder prompt and a
|
Represents an encoder/decoder model input prompt,
|
||||||
decoder prompt.
|
comprising an explicit encoder prompt and a decoder prompt.
|
||||||
|
|
||||||
The encoder and decoder prompts, respectively,
|
The encoder and decoder prompts, respectively,
|
||||||
may be formatted according to any of the
|
may be formatted according to any of the
|
||||||
:class:`SingletonPromptInputs` schemas, and are not
|
:class:`SingletonPrompt` schemas, and are not
|
||||||
required to have the same schema.
|
required to have the same schema.
|
||||||
|
|
||||||
Only the encoder prompt may have multi-modal data.
|
Only the encoder prompt may have multi-modal data.
|
||||||
|
|
||||||
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
|
Note that an :class:`ExplicitEncoderDecoderPrompt` may not
|
||||||
be used as an input to a decoder-only model,
|
be used as an input to a decoder-only model,
|
||||||
and that the `encoder_prompt` and `decoder_prompt`
|
and that the :code:`encoder_prompt` and :code:`decoder_prompt`
|
||||||
fields of this data structure themselves must be
|
fields of this data structure themselves must be
|
||||||
:class:`SingletonPromptInputs` instances.
|
:class:`SingletonPrompt` instances.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
encoder_prompt: _T1_co
|
encoder_prompt: _T1_co
|
||||||
@ -89,7 +89,7 @@ class ExplicitEncoderDecoderPrompt(TypedDict, Generic[_T1_co, _T2_co]):
|
|||||||
decoder_prompt: Optional[_T2_co]
|
decoder_prompt: Optional[_T2_co]
|
||||||
|
|
||||||
|
|
||||||
PromptInputs = Union[SingletonPromptInputs, ExplicitEncoderDecoderPrompt]
|
PromptType = Union[SingletonPrompt, ExplicitEncoderDecoderPrompt]
|
||||||
"""
|
"""
|
||||||
Set of possible schemas for an LLM input, including
|
Set of possible schemas for an LLM input, including
|
||||||
both decoder-only and encoder/decoder input types:
|
both decoder-only and encoder/decoder input types:
|
||||||
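A hedged sketch of the full PromptType union, including the explicit encoder/decoder form; each side may use a different schema, and the token ids here are placeholders:

from vllm.inputs import (ExplicitEncoderDecoderPrompt, PromptType, TextPrompt,
                         TokensPrompt)

# Decoder-only (or encoder-side) singleton prompt.
singleton: PromptType = TextPrompt(prompt="translate English to German: hello")

# Explicit encoder/decoder prompt; only the encoder side may carry
# multi-modal data.
enc_dec: PromptType = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(prompt="translate English to German: hello"),
    decoder_prompt=TokensPrompt(prompt_token_ids=[0]),
)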
@ -140,12 +140,8 @@ class EncoderDecoderLLMInputs(LLMInputs):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
_T1 = TypeVar("_T1",
|
_T1 = TypeVar("_T1", bound=SingletonPrompt, default=SingletonPrompt)
|
||||||
bound=SingletonPromptInputs,
|
_T2 = TypeVar("_T2", bound=SingletonPrompt, default=SingletonPrompt)
|
||||||
default=SingletonPromptInputs)
|
|
||||||
_T2 = TypeVar("_T2",
|
|
||||||
bound=SingletonPromptInputs,
|
|
||||||
default=SingletonPromptInputs)
|
|
||||||
|
|
||||||
|
|
||||||
def build_explicit_enc_dec_prompt(
|
def build_explicit_enc_dec_prompt(
|
||||||
@ -176,3 +172,17 @@ def to_enc_dec_tuple_list(
|
|||||||
return [(enc_dec_prompt["encoder_prompt"],
|
return [(enc_dec_prompt["encoder_prompt"],
|
||||||
enc_dec_prompt["decoder_prompt"])
|
enc_dec_prompt["decoder_prompt"])
|
||||||
for enc_dec_prompt in enc_dec_prompts]
|
for enc_dec_prompt in enc_dec_prompts]
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
if name == "PromptInput":
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
msg = ("PromptInput has been renamed to PromptType. "
|
||||||
|
"The original name will be removed in an upcoming version.")
|
||||||
|
|
||||||
|
warnings.warn(DeprecationWarning(msg), stacklevel=2)
|
||||||
|
|
||||||
|
return PromptType
|
||||||
|
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
@ -5,7 +5,7 @@ from typing_extensions import TypeIs
|
|||||||
from vllm.utils import is_list_of
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
from .data import (EncoderDecoderLLMInputs, ExplicitEncoderDecoderPrompt,
|
||||||
LLMInputs, PromptInputs, SingletonPromptInputs, TextPrompt,
|
LLMInputs, PromptType, SingletonPrompt, TextPrompt,
|
||||||
TokensPrompt)
|
TokensPrompt)
|
||||||
|
|
||||||
|
|
||||||
@ -81,23 +81,23 @@ class ParsedTokensPrompt(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
def parse_singleton_prompt(
|
def parse_singleton_prompt(
|
||||||
inputs: SingletonPromptInputs,
|
prompt: SingletonPrompt,
|
||||||
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
|
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
|
||||||
if isinstance(inputs, str):
|
if isinstance(prompt, str):
|
||||||
return ParsedStrPrompt(type="str", content=inputs)
|
return ParsedStrPrompt(type="str", content=prompt)
|
||||||
elif isinstance(inputs, dict):
|
elif isinstance(prompt, dict):
|
||||||
if "prompt_token_ids" in inputs:
|
if "prompt_token_ids" in prompt:
|
||||||
return ParsedTokensPrompt(type="tokens",
|
return ParsedTokensPrompt(type="tokens",
|
||||||
content=inputs) # type: ignore
|
content=prompt) # type: ignore
|
||||||
elif "prompt" in inputs:
|
elif "prompt" in prompt:
|
||||||
return ParsedTextPrompt(type="text", content=inputs)
|
return ParsedTextPrompt(type="text", content=prompt)
|
||||||
|
|
||||||
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
|
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
|
||||||
|
|
||||||
|
|
||||||
def is_explicit_encoder_decoder_prompt(
|
def is_explicit_encoder_decoder_prompt(
|
||||||
inputs: PromptInputs) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||||
return isinstance(inputs, dict) and "encoder_prompt" in inputs
|
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||||
|
|
||||||
|
|
||||||
def is_valid_encoder_decoder_llm_inputs(
|
def is_valid_encoder_decoder_llm_inputs(
|
||||||
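These parsing helpers are internal, but their behaviour follows directly from the code above; a small sketch using the import path shown in this file (vllm/inputs/parse.py):

from vllm.inputs.parse import (is_explicit_encoder_decoder_prompt,
                               parse_singleton_prompt)

assert parse_singleton_prompt("Hello")["type"] == "str"
assert parse_singleton_prompt({"prompt": "Hello"})["type"] == "text"
assert parse_singleton_prompt({"prompt_token_ids": [1, 2, 3]})["type"] == "tokens"

# Only dicts carrying an "encoder_prompt" key count as explicit
# encoder/decoder prompts.
assert is_explicit_encoder_decoder_prompt(
    {"encoder_prompt": "Hello", "decoder_prompt": None})
assert not is_explicit_encoder_decoder_prompt("Hello")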
|
@ -9,8 +9,8 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||||
|
|
||||||
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptInputs,
|
from .data import (EncoderDecoderLLMInputs, LLMInputs, PromptType,
|
||||||
SingletonPromptInputs)
|
SingletonPrompt)
|
||||||
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -206,7 +206,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def _extract_prompt_components(
|
def _extract_prompt_components(
|
||||||
self,
|
self,
|
||||||
inputs: SingletonPromptInputs,
|
prompt: SingletonPrompt,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
) -> PromptComponents:
|
) -> PromptComponents:
|
||||||
@ -216,7 +216,7 @@ class InputPreprocessor:
|
|||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
* request_id
|
* request_id
|
||||||
* inputs: single encoder or decoder input prompt
|
* prompt: single encoder or decoder input prompt
|
||||||
* lora_request: this is only valid for decoder prompts
|
* lora_request: this is only valid for decoder prompts
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -226,24 +226,24 @@ class InputPreprocessor:
|
|||||||
* multi_modal_data
|
* multi_modal_data
|
||||||
'''
|
'''
|
||||||
|
|
||||||
parsed = parse_singleton_prompt(inputs)
|
parsed = parse_singleton_prompt(prompt)
|
||||||
|
|
||||||
if parsed["type"] == "str":
|
if parsed["type"] == "str":
|
||||||
prompt = parsed["content"]
|
prompt_text = parsed["content"]
|
||||||
prompt_token_ids = self._tokenize_prompt(
|
prompt_token_ids = self._tokenize_prompt(
|
||||||
prompt,
|
prompt_text,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
elif parsed["type"] == "tokens":
|
elif parsed["type"] == "tokens":
|
||||||
prompt = None
|
prompt_text = None
|
||||||
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
||||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||||
elif parsed["type"] == "text":
|
elif parsed["type"] == "text":
|
||||||
prompt = parsed["content"]["prompt"]
|
prompt_text = parsed["content"]["prompt"]
|
||||||
prompt_token_ids = self._tokenize_prompt(
|
prompt_token_ids = self._tokenize_prompt(
|
||||||
prompt,
|
prompt_text,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
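The local variable is renamed to prompt_text because the parameter itself is now called prompt; a contrived, hypothetical mirror of the branching (tokenization and multi-modal data omitted) shows the components being separated without clobbering the argument:

from typing import List, Optional, Tuple, Union

def split_components(
    prompt: Union[str, dict],
) -> Tuple[Optional[str], Optional[List[int]]]:
    # Using prompt_text for the extracted text keeps the prompt argument
    # intact, mirroring the rename in _extract_prompt_components above.
    if isinstance(prompt, str):
        prompt_text, prompt_token_ids = prompt, None
    elif "prompt_token_ids" in prompt:
        prompt_text, prompt_token_ids = None, prompt["prompt_token_ids"]
    else:
        prompt_text, prompt_token_ids = prompt["prompt"], None
    return prompt_text, prompt_token_ids

print(split_components("Hello"))                       # ('Hello', None)
print(split_components({"prompt_token_ids": [1, 2]}))  # (None, [1, 2])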
@ -251,33 +251,33 @@ class InputPreprocessor:
|
|||||||
else:
|
else:
|
||||||
assert_never(parsed)
|
assert_never(parsed)
|
||||||
|
|
||||||
return prompt, prompt_token_ids, multi_modal_data
|
return prompt_text, prompt_token_ids, multi_modal_data
|
||||||
|
|
||||||
async def _extract_prompt_components_async(
|
async def _extract_prompt_components_async(
|
||||||
self,
|
self,
|
||||||
inputs: SingletonPromptInputs,
|
prompt: SingletonPrompt,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
) -> PromptComponents:
|
) -> PromptComponents:
|
||||||
"""Async version of :meth:`_extract_prompt_components`."""
|
"""Async version of :meth:`_extract_prompt_components`."""
|
||||||
parsed = parse_singleton_prompt(inputs)
|
parsed = parse_singleton_prompt(prompt)
|
||||||
|
|
||||||
if parsed["type"] == "str":
|
if parsed["type"] == "str":
|
||||||
prompt = parsed["content"]
|
prompt_text = parsed["content"]
|
||||||
prompt_token_ids = await self._tokenize_prompt_async(
|
prompt_token_ids = await self._tokenize_prompt_async(
|
||||||
prompt,
|
prompt_text,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
multi_modal_data = None
|
multi_modal_data = None
|
||||||
elif parsed["type"] == "tokens":
|
elif parsed["type"] == "tokens":
|
||||||
prompt = None
|
prompt_text = None
|
||||||
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
prompt_token_ids = parsed["content"]["prompt_token_ids"]
|
||||||
multi_modal_data = parsed["content"].get("multi_modal_data")
|
multi_modal_data = parsed["content"].get("multi_modal_data")
|
||||||
elif parsed["type"] == "text":
|
elif parsed["type"] == "text":
|
||||||
prompt = parsed["content"]["prompt"]
|
prompt_text = parsed["content"]["prompt"]
|
||||||
prompt_token_ids = await self._tokenize_prompt_async(
|
prompt_token_ids = await self._tokenize_prompt_async(
|
||||||
prompt,
|
prompt_text,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
@ -285,7 +285,7 @@ class InputPreprocessor:
|
|||||||
else:
|
else:
|
||||||
assert_never(parsed)
|
assert_never(parsed)
|
||||||
|
|
||||||
return prompt, prompt_token_ids, multi_modal_data
|
return prompt_text, prompt_token_ids, multi_modal_data
|
||||||
|
|
||||||
def _build_enc_dec_llm_inputs(
|
def _build_enc_dec_llm_inputs(
|
||||||
self,
|
self,
|
||||||
@ -311,7 +311,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def _process_encoder_decoder_prompt(
|
def _process_encoder_decoder_prompt(
|
||||||
self,
|
self,
|
||||||
inputs: PromptInputs,
|
prompt: PromptType,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
) -> EncoderDecoderLLMInputs:
|
) -> EncoderDecoderLLMInputs:
|
||||||
'''
|
'''
|
||||||
@ -339,7 +339,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
* inputs: an input prompt
|
* prompt: an input prompt
|
||||||
* request_id
|
* request_id
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -350,13 +350,13 @@ class InputPreprocessor:
|
|||||||
encoder_comps: PromptComponents
|
encoder_comps: PromptComponents
|
||||||
decoder_comps: DecoderPromptComponents
|
decoder_comps: DecoderPromptComponents
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(inputs):
|
if is_explicit_encoder_decoder_prompt(prompt):
|
||||||
encoder_comps = self._extract_prompt_components(
|
encoder_comps = self._extract_prompt_components(
|
||||||
inputs["encoder_prompt"],
|
prompt["encoder_prompt"],
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (decoder_input := inputs["decoder_prompt"]) is None:
|
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||||
decoder_comps = None, None, None
|
decoder_comps = None, None, None
|
||||||
else:
|
else:
|
||||||
decoder_comps = self._extract_prompt_components(
|
decoder_comps = self._extract_prompt_components(
|
||||||
@ -365,7 +365,7 @@ class InputPreprocessor:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
encoder_comps = self._extract_prompt_components(
|
encoder_comps = self._extract_prompt_components(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -375,20 +375,20 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
async def _process_encoder_decoder_prompt_async(
|
async def _process_encoder_decoder_prompt_async(
|
||||||
self,
|
self,
|
||||||
inputs: PromptInputs,
|
prompt: PromptType,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
) -> EncoderDecoderLLMInputs:
|
) -> EncoderDecoderLLMInputs:
|
||||||
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
"""Async version of :meth:`_process_encoder_decoder_prompt`."""
|
||||||
encoder_comps: PromptComponents
|
encoder_comps: PromptComponents
|
||||||
decoder_comps: DecoderPromptComponents
|
decoder_comps: DecoderPromptComponents
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(inputs):
|
if is_explicit_encoder_decoder_prompt(prompt):
|
||||||
encoder_task = self._extract_prompt_components_async(
|
encoder_task = self._extract_prompt_components_async(
|
||||||
inputs["encoder_prompt"],
|
prompt["encoder_prompt"],
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (decoder_input := inputs["decoder_prompt"]) is None:
|
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||||
encoder_comps = await encoder_task
|
encoder_comps = await encoder_task
|
||||||
decoder_comps = None, None, None
|
decoder_comps = None, None, None
|
||||||
else:
|
else:
|
||||||
@ -401,7 +401,7 @@ class InputPreprocessor:
|
|||||||
encoder_task, decoder_task)
|
encoder_task, decoder_task)
|
||||||
else:
|
else:
|
||||||
encoder_comps = await self._extract_prompt_components_async(
|
encoder_comps = await self._extract_prompt_components_async(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
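In the async path the encoder and decoder prompts are extracted concurrently via asyncio.gather; a self-contained sketch of the same pattern with a stand-in coroutine (names hypothetical):

import asyncio
from typing import List

async def fake_tokenize(text: str) -> List[int]:
    # Stand-in for the real async tokenization; purely illustrative.
    await asyncio.sleep(0)
    return [len(word) for word in text.split()]

async def process_enc_dec(encoder_prompt: str, decoder_prompt: str):
    encoder_task = fake_tokenize(encoder_prompt)
    decoder_task = fake_tokenize(decoder_prompt)
    # Both prompts are processed concurrently, mirroring the
    # asyncio.gather(encoder_task, decoder_task) call above.
    return await asyncio.gather(encoder_task, decoder_task)

print(asyncio.run(process_enc_dec("encode this please", "start decoding")))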
@ -425,7 +425,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def _process_decoder_only_prompt(
|
def _process_decoder_only_prompt(
|
||||||
self,
|
self,
|
||||||
inputs: SingletonPromptInputs,
|
prompt: SingletonPrompt,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -436,7 +436,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
|
|
||||||
* inputs: input prompt
|
* prompt: input prompt
|
||||||
* request_id
|
* request_id
|
||||||
* lora_request
|
* lora_request
|
||||||
* prompt_adapter_request
|
* prompt_adapter_request
|
||||||
@ -447,7 +447,7 @@ class InputPreprocessor:
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
prompt_comps = self._extract_prompt_components(
|
prompt_comps = self._extract_prompt_components(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
@ -459,14 +459,14 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
async def _process_decoder_only_prompt_async(
|
async def _process_decoder_only_prompt_async(
|
||||||
self,
|
self,
|
||||||
inputs: SingletonPromptInputs,
|
prompt: SingletonPrompt,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
) -> LLMInputs:
|
) -> LLMInputs:
|
||||||
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
"""Async version of :meth:`_process_decoder_only_prompt`."""
|
||||||
prompt_comps = await self._extract_prompt_components_async(
|
prompt_comps = await self._extract_prompt_components_async(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
)
|
)
|
||||||
@ -478,7 +478,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
def preprocess(
|
def preprocess(
|
||||||
self,
|
self,
|
||||||
inputs: PromptInputs,
|
prompt: PromptType,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -488,17 +488,17 @@ class InputPreprocessor:
|
|||||||
# Encoder-decoder model requires special mapping of
|
# Encoder-decoder model requires special mapping of
|
||||||
# input prompts to encoder & decoder
|
# input prompts to encoder & decoder
|
||||||
return self._process_encoder_decoder_prompt(
|
return self._process_encoder_decoder_prompt(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(inputs):
|
if is_explicit_encoder_decoder_prompt(prompt):
|
||||||
raise ValueError("Cannot pass encoder-decoder prompt "
|
raise ValueError("Cannot pass encoder-decoder prompt "
|
||||||
"to decoder-only models")
|
"to decoder-only models")
|
||||||
|
|
||||||
# Decoder-only operation
|
# Decoder-only operation
|
||||||
return self._process_decoder_only_prompt(
|
return self._process_decoder_only_prompt(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
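The dispatch above is the user-visible contract: an explicit encoder/decoder prompt handed to a decoder-only model is rejected. A hedged end-to-end illustration (model name illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # a decoder-only model, illustrative

try:
    llm.generate({"encoder_prompt": "Hello", "decoder_prompt": None},
                 SamplingParams(max_tokens=8))
except ValueError as err:
    print(err)  # "Cannot pass encoder-decoder prompt to decoder-only models"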
@ -506,7 +506,7 @@ class InputPreprocessor:
|
|||||||
|
|
||||||
async def preprocess_async(
|
async def preprocess_async(
|
||||||
self,
|
self,
|
||||||
inputs: PromptInputs,
|
prompt: PromptType,
|
||||||
request_id: str,
|
request_id: str,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional[LoRARequest] = None,
|
||||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||||
@ -516,17 +516,17 @@ class InputPreprocessor:
|
|||||||
# Encoder-decoder model requires special mapping of
|
# Encoder-decoder model requires special mapping of
|
||||||
# input prompts to encoder & decoder
|
# input prompts to encoder & decoder
|
||||||
return await self._process_encoder_decoder_prompt_async(
|
return await self._process_encoder_decoder_prompt_async(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_explicit_encoder_decoder_prompt(inputs):
|
if is_explicit_encoder_decoder_prompt(prompt):
|
||||||
raise ValueError("Cannot pass encoder-decoder prompt "
|
raise ValueError("Cannot pass encoder-decoder prompt "
|
||||||
"to decoder-only models")
|
"to decoder-only models")
|
||||||
|
|
||||||
# Decoder-only operation
|
# Decoder-only operation
|
||||||
return await self._process_decoder_only_prompt_async(
|
return await self._process_decoder_only_prompt_async(
|
||||||
inputs,
|
prompt,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
lora_request=lora_request,
|
lora_request=lora_request,
|
||||||
prompt_adapter_request=prompt_adapter_request,
|
prompt_adapter_request=prompt_adapter_request,
|
||||||
|