[V1] Avoid redundant input processing in n>1 case (#14985)
Signed-off-by: Nick Hill <nhill@redhat.com>
Parent: 7297941b38
Commit: da6ea29f7a
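The hunks below all implement one pattern: input preprocessing (tokenization, multimodal processing, prompt-adapter handling) now runs once per incoming request, and for parallel sampling (n>1) the already-processed request is shallow-copied for each child instead of being reprocessed n times; the request_id plumbing through the tokenizer and preprocessor APIs, which only existed to label those redundant passes, is dropped. A minimal, self-contained sketch of the idea follows; the names (Request, process_inputs, fan_out) and the child-id format are illustrative stand-ins, not the vLLM API.

from copy import copy
from dataclasses import dataclass, field


@dataclass
class Request:
    """Toy stand-in for an engine request (not vLLM's EngineCoreRequest)."""
    request_id: str
    prompt: str
    token_ids: list[int] = field(default_factory=list)


def process_inputs(request_id: str, prompt: str) -> Request:
    # The expensive step (tokenization etc.) -- run it exactly once.
    return Request(request_id, prompt, token_ids=[ord(c) for c in prompt])


def fan_out(request_id: str, prompt: str, n: int) -> list[Request]:
    request = process_inputs(request_id, prompt)  # processed once, up front
    if n == 1:
        return [request]
    children = []
    for idx in range(n):
        # Reuse the original object for the last child to avoid one extra copy;
        # copy() is shallow, so the token_ids list is shared across children.
        child = request if idx == n - 1 else copy(request)
        child.request_id = f"{request_id}_{idx}"
        children.append(child)
    return children


if __name__ == "__main__":
    for child in fan_out("req-0", "hi", n=3):
        print(child.request_id, child.token_ids)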
@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
     )
     lora_request = LoRARequest("1", 1, sql_lora_files)
     assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
-        request_id="request_id", prompt="prompt", lora_request=lora_request)
+        prompt="prompt", lora_request=lora_request)
     assert reference_tokenizer.encode(
         "prompt") == await tokenizer_group.encode_async(
-            request_id="request_id",
-            prompt="prompt",
-            lora_request=lora_request)
+            prompt="prompt", lora_request=lora_request)
     assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                       PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
         max_input_length=None,
     )
     assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
-        request_id="request_id", prompt="prompt", lora_request=None)
+        prompt="prompt", lora_request=None)
     assert reference_tokenizer.encode(
-        "prompt") == await tokenizer_group.encode_async(
-            request_id="request_id", prompt="prompt", lora_request=None)
+        "prompt") == await tokenizer_group.encode_async(prompt="prompt",
+                                                        lora_request=None)
     assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                       PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
     # and check that all requests are processed correctly.
     num_requests = tokenizer_group_pool.pool_size * 5
     requests = [
-        tokenizer_group_pool.encode_async(request_id=str(i),
-                                          prompt=f"prompt {i}",
+        tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
                                           lora_request=None)
         for i in range(num_requests)
     ]
@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
     fail_at[0] = 1000
 
     # We should recover successfully.
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
 
     # Check that we have a new actor
     assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
 
     # We should fail after re-initialization.
     with pytest.raises(RuntimeError):
-        await tokenizer_group_pool.encode_async(request_id="1",
-                                                prompt="prompt",
+        await tokenizer_group_pool.encode_async(prompt="prompt",
                                                 lora_request=None)
 
     # check_health should raise the same thing
@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
 
     # Prompt too long error
    with pytest.raises(ValueError):
-        await tokenizer_group_pool.encode_async(request_id="1",
-                                                prompt="prompt" * 100,
+        await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
                                                 lora_request=None)
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
     # Actors should stay the same.
     assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
@@ -492,7 +492,6 @@ class _AsyncLLMEngine(LLMEngine):
 
         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
@@ -783,7 +783,6 @@ class LLMEngine:
 
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )
@@ -81,10 +81,7 @@ class EngineClient(ABC):
         if is_explicit_encoder_decoder_prompt(prompt):
             raise NotImplementedError
         else:
-            processed_inputs = preprocessor._prompt_to_llm_inputs(
-                prompt,
-                request_id=request_id,
-            )
+            processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
 
         prompt_token_ids = processed_inputs["prompt_token_ids"]
         prompt_text = processed_inputs.get("prompt")
@@ -182,7 +182,6 @@ class InputPreprocessor:
     def _tokenize_prompt(
         self,
         prompt: str,
-        request_id: str,
         lora_request: Optional[LoRARequest],
     ) -> list[int]:
         """
@@ -202,15 +201,13 @@ class InputPreprocessor:
                     "do_lower_case", False)):
             prompt = prompt.lower()
 
-        return tokenizer.encode(request_id=request_id,
-                                prompt=prompt,
+        return tokenizer.encode(prompt=prompt,
                                 lora_request=lora_request,
                                 add_special_tokens=add_special_tokens)
 
     async def _tokenize_prompt_async(
         self,
         prompt: str,
-        request_id: str,
         lora_request: Optional[LoRARequest],
     ) -> list[int]:
         """Async version of :meth:`_tokenize_prompt`."""
@@ -222,7 +219,6 @@ class InputPreprocessor:
             # appending an EOS token to the prompt which disrupts generation.
             add_special_tokens = False
         return await tokenizer.encode_async(
-            request_id=request_id,
             prompt=prompt,
             lora_request=lora_request,
             add_special_tokens=add_special_tokens)
@@ -309,7 +305,6 @@ class InputPreprocessor:
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -318,7 +313,6 @@ class InputPreprocessor:
 
         Arguments:
 
-        * request_id
         * prompt: single encoder or decoder input prompt
         * lora_request: this is only valid for decoder prompts
         * return_mm_hashes: whether to return multimodal hashes
@@ -333,7 +327,6 @@ class InputPreprocessor:
             prompt_text = parsed["content"]
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
-                request_id=request_id,
                 lora_request=lora_request,
             )
 
@@ -384,7 +377,6 @@ class InputPreprocessor:
 
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
-                request_id=request_id,
                 lora_request=lora_request,
             )
 
@@ -400,7 +392,6 @@ class InputPreprocessor:
     async def _prompt_to_llm_inputs_async(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -411,7 +402,6 @@ class InputPreprocessor:
             prompt_text = parsed["content"]
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
-                request_id=request_id,
                 lora_request=lora_request,
             )
 
@@ -460,7 +450,6 @@ class InputPreprocessor:
 
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
-                request_id=request_id,
                 lora_request=lora_request,
             )
 
@@ -560,7 +549,6 @@ class InputPreprocessor:
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
-        request_id: str,
     ) -> EncoderDecoderInputs:
         """
         For encoder/decoder models only:
@@ -587,7 +575,6 @@ class InputPreprocessor:
         Arguments:
 
         * prompt: an input prompt
-        * request_id
 
         Returns:
 
@@ -598,16 +585,11 @@ class InputPreprocessor:
 
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_inputs = self._prompt_to_llm_inputs(
-                prompt["encoder_prompt"],
-                request_id=request_id,
-            )
+                prompt["encoder_prompt"])
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
-                decoder_inputs = self._prompt_to_llm_inputs(
-                    decoder_input,
-                    request_id=request_id,
-                )
+                decoder_inputs = self._prompt_to_llm_inputs(decoder_input)
             # For multimodal model, override decoder prompt from processor
             # with explicit decoder prompt.
             if self.model_config.is_multimodal_model and (
@@ -616,10 +598,7 @@ class InputPreprocessor:
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
-            inputs = self._prompt_to_llm_inputs(
-                prompt,
-                request_id=request_id,
-            )
+            inputs = self._prompt_to_llm_inputs(prompt)
             if self.model_config.is_multimodal_model and (
                     self._can_process_multimodal()):
                 # Encoder-Decoder Multimodal model
@@ -636,7 +615,6 @@ class InputPreprocessor:
     async def _process_encoder_decoder_prompt_async(
         self,
         prompt: PromptType,
-        request_id: str,
     ) -> EncoderDecoderInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_inputs: SingletonInputs
@@ -644,18 +622,13 @@ class InputPreprocessor:
 
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._prompt_to_llm_inputs_async(
-                prompt["encoder_prompt"],
-                request_id=request_id,
-            )
+                prompt["encoder_prompt"])
 
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_inputs = await encoder_task
                 decoder_inputs = None
             else:
-                decoder_task = self._prompt_to_llm_inputs_async(
-                    decoder_input,
-                    request_id=request_id,
-                )
+                decoder_task = self._prompt_to_llm_inputs_async(decoder_input)
 
                 encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)
@@ -668,10 +641,7 @@ class InputPreprocessor:
                     self._separate_enc_dec_inputs_from_mm_processor_outputs(
                         encoder_inputs, decoder_inputs))
         else:
-            inputs = await self._prompt_to_llm_inputs_async(
-                prompt,
-                request_id=request_id,
-            )
+            inputs = await self._prompt_to_llm_inputs_async(prompt)
             if self.model_config.is_multimodal_model and (
                     self._can_process_multimodal()):
                 # Encoder-Decoder Multimodal model
@@ -704,7 +674,6 @@ class InputPreprocessor:
     def _process_decoder_only_prompt(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -716,7 +685,6 @@ class InputPreprocessor:
         Arguments:
 
         * prompt: input prompt
-        * request_id
         * lora_request
         * prompt_adapter_request
         * return_mm_hashes
@@ -728,7 +696,6 @@ class InputPreprocessor:
 
         prompt_comps = self._prompt_to_llm_inputs(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -741,7 +708,6 @@ class InputPreprocessor:
     async def _process_decoder_only_prompt_async(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -749,7 +715,6 @@ class InputPreprocessor:
         """Async version of :meth:`_process_decoder_only_prompt`."""
        prompt_comps = await self._prompt_to_llm_inputs_async(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -762,7 +727,6 @@ class InputPreprocessor:
     def preprocess(
         self,
         prompt: PromptType,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -774,10 +738,7 @@ class InputPreprocessor:
                 "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
-            return self._process_encoder_decoder_prompt(
-                prompt,
-                request_id=request_id,
-            )
+            return self._process_encoder_decoder_prompt(prompt)
 
         if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
@@ -786,7 +747,6 @@ class InputPreprocessor:
         # Decoder-only operation
         return self._process_decoder_only_prompt(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
@@ -795,7 +755,6 @@ class InputPreprocessor:
     async def preprocess_async(
         self,
         prompt: PromptType,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -807,10 +766,7 @@ class InputPreprocessor:
                 "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
-            return await self._process_encoder_decoder_prompt_async(
-                prompt,
-                request_id=request_id,
-            )
+            return await self._process_encoder_decoder_prompt_async(prompt)
 
         if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
@@ -819,7 +775,6 @@ class InputPreprocessor:
         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC):
     @abstractmethod
     def encode(self,
                prompt: str,
-               request_id: Optional[str] = None,
                lora_request: Optional[LoRARequest] = None,
                add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC):
     async def encode_async(
             self,
             prompt: str,
-            request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None,
             add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
 
     def encode(self,
                prompt: str,
-               request_id: Optional[str] = None,
                lora_request: Optional[LoRARequest] = None,
                add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         original_actor = actor
         try:
             ret = ray.get(
-                actor.encode.remote(request_id=request_id,
-                                    prompt=prompt,
+                actor.encode.remote(prompt=prompt,
                                     lora_request=lora_request,
                                     add_special_tokens=add_special_tokens))
         except ActorDiedError as e:
@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             actor = self._init_actor()
             try:
                 ret = ray.get(
-                    actor.encode.remote(request_id=request_id,
-                                        prompt=prompt,
+                    actor.encode.remote(prompt=prompt,
                                         lora_request=lora_request,
                                         add_special_tokens=add_special_tokens))
             except ActorDiedError as e:
@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
     async def encode_async(
             self,
             prompt: str,
-            request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None,
             add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         original_actor = actor
         try:
             ret = await actor.encode.remote(
-                request_id=request_id,
                 prompt=prompt,
                 lora_request=lora_request,
                 add_special_tokens=add_special_tokens)
@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             actor = self._init_actor()
             try:
                 ret = await actor.encode.remote(
-                    request_id=request_id,
                     prompt=prompt,
                     lora_request=lora_request,
                     add_special_tokens=add_special_tokens)
@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup):
 
     def encode(self,
                prompt: str,
-               request_id: Optional[str] = None,
                lora_request: Optional[LoRARequest] = None,
                add_special_tokens: Optional[bool] = None) -> List[int]:
         tokenizer = self.get_lora_tokenizer(lora_request)
@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup):
     async def encode_async(
             self,
             prompt: str,
-            request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None,
             add_special_tokens: Optional[bool] = None) -> List[int]:
         tokenizer = await self.get_lora_tokenizer_async(lora_request)
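For reference, the trimmed-down encode()/encode_async() call shape after the signature changes above, shown on a dummy class; DummyTokenizerGroup is an illustrative stand-in, not vLLM's TokenizerGroup, and the fake token ids are just ord() values.

from typing import List, Optional


class DummyTokenizerGroup:
    """Implements only the post-change encode() signature for illustration."""

    def encode(self,
               prompt: str,
               lora_request: Optional[object] = None,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        # Real implementations pick a (possibly LoRA-specific) tokenizer and
        # encode the prompt; a request_id is no longer threaded through here.
        return [ord(ch) for ch in prompt]


group = DummyTokenizerGroup()
assert group.encode(prompt="prompt", lora_request=None) == group.encode(
    prompt="prompt")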
@@ -4,6 +4,7 @@ import asyncio
 import logging
 import os
 from collections.abc import AsyncGenerator, Mapping
+from copy import copy
 from typing import Optional, Union
 
 import numpy as np
@@ -25,6 +26,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv, kill_process_tree
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -177,34 +179,45 @@ class AsyncLLM(EngineClient):
     ) -> asyncio.Queue[RequestOutput]:
         """Add new request to the AsyncLLM."""
 
-        # 1) Create a new output queue for the request.
+        # Create a new output queue for the request.
         queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
 
-        # 2) Fan out child requests (for n>1)
-        parent_req = ParentRequest.from_params(request_id, params)
+        # Convert Input --> Request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)
+
         n = params.n if isinstance(params, SamplingParams) else 1
+
+        if n == 1:
+            await self._add_request(request, None, 0, queue)
+            return queue
+
+        # Fan out child requests (for n>1).
+        parent_request = ParentRequest(request_id, params)
         for idx in range(n):
-            if parent_req is not None:
-                request_id, params = parent_req.get_child_info(idx)
-
-            # 3) Convert Input --> Request.
-            request = self.processor.process_inputs(request_id, prompt, params,
-                                                    arrival_time, lora_request,
-                                                    trace_headers,
-                                                    prompt_adapter_request,
-                                                    priority)
-
-            # 4) Add the request to OutputProcessor (this process).
-            self.output_processor.add_request(request, parent_req, idx, queue)
-
-            # 5) Add the EngineCoreRequest to EngineCore (separate process).
-            await self.engine_core.add_request_async(request)
-
-            if self.log_requests:
-                logger.info("Added request %s.", request_id)
-
+            request_id, params = parent_request.get_child_info(idx)
+            child_request = request if idx == n - 1 else copy(request)
+            child_request.request_id = request_id
+            child_request.sampling_params = params
+            await self._add_request(child_request, parent_request, idx, queue)
         return queue
 
+    async def _add_request(self, request: EngineCoreRequest,
+                           parent_req: Optional[ParentRequest], index: int,
+                           queue: asyncio.Queue[RequestOutput]):
+
+        # Add the request to OutputProcessor (this process).
+        self.output_processor.add_request(request, parent_req, index, queue)
+
+        # Add the EngineCoreRequest to EngineCore (separate process).
+        await self.engine_core.add_request_async(request)
+
+        if self.log_requests:
+            logger.info("Added request %s.", request.request_id)
+
     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
     # requests we don't need to send multiple messages to core proc,
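The copy(request) in the new fan-out loop above (and in the synchronous LLMEngine below) is a shallow copy: child requests share the already-processed prompt data and only carry their own request id and sampling params. A small illustration of that sharing, using a hypothetical dataclass rather than vLLM's EngineCoreRequest:

from copy import copy
from dataclasses import dataclass


@dataclass
class FakeCoreRequest:
    # Hypothetical stand-in; the real EngineCoreRequest has more fields.
    request_id: str
    prompt_token_ids: list[int]
    sampling_n: int


parent = FakeCoreRequest("parent", prompt_token_ids=[1, 2, 3], sampling_n=3)
child = copy(parent)
child.request_id = "parent_0"
child.sampling_n = 1

# Shallow copy: the token-id list object is shared, not duplicated.
assert child.prompt_token_ids is parent.prompt_token_ids
# Per-child fields are independent of the parent's.
assert (parent.request_id, parent.sampling_n) == ("parent", 3)
assert (child.request_id, child.sampling_n) == ("parent_0", 1)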
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections.abc import Mapping
+from copy import copy
 from typing import Optional, Union
 
 from typing_extensions import TypeVar
@@ -179,25 +180,34 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
-        # 1) Fan out child requests (for n>1)
-        parent_req = ParentRequest.from_params(request_id, params)
+        # Process raw inputs into the request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)
+
         n = params.n if isinstance(params, SamplingParams) else 1
-        for idx in range(n):
-            if parent_req is not None:
-                request_id, params = parent_req.get_child_info(idx)
-
-            # 2) Process raw inputs into the request.
-            request = self.processor.process_inputs(request_id, prompt, params,
-                                                    arrival_time, lora_request,
-                                                    trace_headers,
-                                                    prompt_adapter_request,
-                                                    priority)
-
-            # 3) Make a new RequestState and queue.
-            self.output_processor.add_request(request, parent_req, idx)
-
-            # 3) Add the request to EngineCore.
+
+        if n == 1:
+            # Make a new RequestState and queue.
+            self.output_processor.add_request(request, None, 0)
+            # Add the request to EngineCore.
             self.engine_core.add_request(request)
+            return
+
+        # Fan out child requests (for n>1).
+        parent_req = ParentRequest(request_id, params)
+        for idx in range(n):
+            request_id, params = parent_req.get_child_info(idx)
+            child_request = request if idx == n - 1 else copy(request)
+            child_request.request_id = request_id
+            child_request.sampling_params = params
+
+            # Make a new RequestState and queue.
+            self.output_processor.add_request(child_request, parent_req, idx)
+            # Add the request to EngineCore.
+            self.engine_core.add_request(child_request)
 
     def step(self) -> list[RequestOutput]:
 
@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from copy import copy
-from typing import Optional, Union
+from typing import Optional
 
 from vllm.outputs import CompletionOutput
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.v1.metrics.stats import IterationStats
 
@@ -43,16 +42,6 @@ class ParentRequest:
         self.max_num_generation_tokens = 0
         self.cached_child_sampling_params = None
 
-    @classmethod
-    def from_params(
-        cls,
-        request_id: str,
-        params: Union[SamplingParams, PoolingParams],
-    ) -> Optional['ParentRequest']:
-        if not isinstance(params, SamplingParams) or params.n == 1:
-            return None
-        return cls(request_id, params)
-
     def _get_child_sampling_params(
         self,
         index: int,
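With from_params removed, callers construct ParentRequest directly and only on the n>1 path (see the AsyncLLM and LLMEngine hunks above), so an Optional-returning factory is no longer needed. A simplified, hypothetical helper in the same spirit; the child-id format and the caching detail here are illustrative, not copied from vLLM:

from copy import deepcopy
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniSamplingParams:
    n: int = 1


class MiniParentRequest:
    """Illustrative parent-request bookkeeping for n > 1 sampling."""

    def __init__(self, request_id: str, params: MiniSamplingParams) -> None:
        self.request_id = request_id
        self.params = params
        self._cached_child_params: Optional[MiniSamplingParams] = None

    def get_child_info(self, index: int) -> tuple:
        # Each child samples a single sequence; derive (and cache) its params.
        if self._cached_child_params is None:
            child_params = deepcopy(self.params)
            child_params.n = 1
            self._cached_child_params = child_params
        return f"{self.request_id}_{index}", self._cached_child_params


parent = MiniParentRequest("req-0", MiniSamplingParams(n=3))
print([parent.get_child_info(i)[0] for i in range(3)])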
@@ -173,7 +173,6 @@ class Processor:
         # 3. Apply prompt adapter to prompt token ids if one exists.
         processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=self.use_hash,