[V1] Avoid redundant input processing in n>1 case (#14985)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-03-20 22:24:10 -07:00 committed by GitHub
parent 7297941b38
commit da6ea29f7a
GPG Key ID: B5690EEEBB952194
13 changed files with 85 additions and 145 deletions
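In short: with parallel sampling (n>1), input processing (tokenization and any multimodal preprocessing) previously ran once per child request; after this change it runs once per request, and the processed request is shallow-copied for each child. A minimal, self-contained sketch of the pattern, using illustrative stand-ins (Request, expensive_preprocess, submit) rather than vLLM's actual classes:

from copy import copy
from dataclasses import dataclass


# Illustrative stand-ins only (not vLLM classes); the point is that
# expensive_preprocess() runs once and every child reuses its result.
@dataclass
class Request:
    request_id: str
    prompt_token_ids: list[int]


def expensive_preprocess(request_id: str, prompt: str) -> Request:
    # Stand-in for tokenization / multimodal preprocessing.
    return Request(request_id, [ord(c) for c in prompt])


def add_request(request_id: str, prompt: str, n: int, submit=print) -> None:
    request = expensive_preprocess(request_id, prompt)  # done exactly once
    if n == 1:
        submit(request)
        return
    for idx in range(n):
        # Reuse the original object for the last child; shallow-copy the rest.
        child = request if idx == n - 1 else copy(request)
        child.request_id = f"{request_id}_{idx}"
        submit(child)


add_request("req-0", "hello", n=3)

The last child reuses the original object rather than a copy, mirroring the "request if idx == n - 1 else copy(request)" line in the diffs below.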

View File

@@ -24,12 +24,10 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type):
     )
     lora_request = LoRARequest("1", 1, sql_lora_files)
     assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
-        request_id="request_id", prompt="prompt", lora_request=lora_request)
+        prompt="prompt", lora_request=lora_request)
     assert reference_tokenizer.encode(
         "prompt") == await tokenizer_group.encode_async(
-            request_id="request_id",
-            prompt="prompt",
-            lora_request=lora_request)
+            prompt="prompt", lora_request=lora_request)
     assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                       PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(

View File

@@ -41,10 +41,10 @@ async def test_tokenizer_group(tokenizer_group_type):
         max_input_length=None,
     )
     assert reference_tokenizer.encode("prompt") == tokenizer_group.encode(
-        request_id="request_id", prompt="prompt", lora_request=None)
+        prompt="prompt", lora_request=None)
     assert reference_tokenizer.encode(
-        "prompt") == await tokenizer_group.encode_async(
-            request_id="request_id", prompt="prompt", lora_request=None)
+        "prompt") == await tokenizer_group.encode_async(prompt="prompt",
+                                                        lora_request=None)
     assert isinstance(tokenizer_group.get_lora_tokenizer(None),
                       PreTrainedTokenizerBase)
     assert tokenizer_group.get_lora_tokenizer(
@@ -69,8 +69,7 @@ async def test_tokenizer_group_pool(tokenizer_group_type):
     # and check that all requests are processed correctly.
     num_requests = tokenizer_group_pool.pool_size * 5
     requests = [
-        tokenizer_group_pool.encode_async(request_id=str(i),
-                                          prompt=f"prompt {i}",
+        tokenizer_group_pool.encode_async(prompt=f"prompt {i}",
                                           lora_request=None)
         for i in range(num_requests)
     ]
@@ -161,12 +160,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
     fail_at[0] = 1000

     # We should recover successfully.
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)

     # Check that we have a new actor
     assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
@@ -184,8 +179,7 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):

     # We should fail after re-initialization.
     with pytest.raises(RuntimeError):
-        await tokenizer_group_pool.encode_async(request_id="1",
-                                                prompt="prompt",
+        await tokenizer_group_pool.encode_async(prompt="prompt",
                                                 lora_request=None)

     # check_health should raise the same thing
@@ -206,11 +200,8 @@ async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
     # Prompt too long error
     with pytest.raises(ValueError):
-        await tokenizer_group_pool.encode_async(request_id="1",
-                                                prompt="prompt" * 100,
+        await tokenizer_group_pool.encode_async(prompt="prompt" * 100,
                                                 lora_request=None)
-    await tokenizer_group_pool.encode_async(request_id="1",
-                                            prompt="prompt",
-                                            lora_request=None)
+    await tokenizer_group_pool.encode_async(prompt="prompt", lora_request=None)

     # Actors should stay the same.
     assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors

View File

@@ -492,7 +492,6 @@ class _AsyncLLMEngine(LLMEngine):

         preprocessed_inputs = await self.input_preprocessor.preprocess_async(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )

View File

@@ -783,7 +783,6 @@ class LLMEngine:

         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
         )

View File

@@ -81,10 +81,7 @@ class EngineClient(ABC):
         if is_explicit_encoder_decoder_prompt(prompt):
             raise NotImplementedError
         else:
-            processed_inputs = preprocessor._prompt_to_llm_inputs(
-                prompt,
-                request_id=request_id,
-            )
+            processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)

         prompt_token_ids = processed_inputs["prompt_token_ids"]
         prompt_text = processed_inputs.get("prompt")

View File

@@ -182,7 +182,6 @@ class InputPreprocessor:
     def _tokenize_prompt(
         self,
         prompt: str,
-        request_id: str,
         lora_request: Optional[LoRARequest],
     ) -> list[int]:
         """
@@ -202,15 +201,13 @@ class InputPreprocessor:
                     "do_lower_case", False)):
             prompt = prompt.lower()

-        return tokenizer.encode(request_id=request_id,
-                                prompt=prompt,
+        return tokenizer.encode(prompt=prompt,
                                 lora_request=lora_request,
                                 add_special_tokens=add_special_tokens)

     async def _tokenize_prompt_async(
         self,
         prompt: str,
-        request_id: str,
         lora_request: Optional[LoRARequest],
     ) -> list[int]:
         """Async version of :meth:`_tokenize_prompt`."""
@@ -222,7 +219,6 @@ class InputPreprocessor:
             # appending an EOS token to the prompt which disrupts generation.
             add_special_tokens = False
         return await tokenizer.encode_async(
-            request_id=request_id,
            prompt=prompt,
            lora_request=lora_request,
            add_special_tokens=add_special_tokens)
@@ -309,7 +305,6 @@ class InputPreprocessor:
     def _prompt_to_llm_inputs(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -318,7 +313,6 @@ class InputPreprocessor:
         Arguments:

-        * request_id
         * prompt: single encoder or decoder input prompt
         * lora_request: this is only valid for decoder prompts
         * return_mm_hashes: whether to return multimodal hashes

@@ -333,7 +327,6 @@ class InputPreprocessor:
             prompt_text = parsed["content"]
             prompt_token_ids = self._tokenize_prompt(
                 prompt_text,
-                request_id=request_id,
                 lora_request=lora_request,
             )

@@ -384,7 +377,6 @@ class InputPreprocessor:

         prompt_token_ids = self._tokenize_prompt(
             prompt_text,
-            request_id=request_id,
             lora_request=lora_request,
         )

@@ -400,7 +392,6 @@ class InputPreprocessor:
     async def _prompt_to_llm_inputs_async(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         return_mm_hashes: bool = False,
     ) -> SingletonInputs:
@@ -411,7 +402,6 @@ class InputPreprocessor:
             prompt_text = parsed["content"]
             prompt_token_ids = await self._tokenize_prompt_async(
                 prompt_text,
-                request_id=request_id,
                 lora_request=lora_request,
             )

@@ -460,7 +450,6 @@ class InputPreprocessor:

         prompt_token_ids = await self._tokenize_prompt_async(
             prompt_text,
-            request_id=request_id,
             lora_request=lora_request,
         )

@@ -560,7 +549,6 @@ class InputPreprocessor:
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
-        request_id: str,
     ) -> EncoderDecoderInputs:
         """
         For encoder/decoder models only:
@@ -587,7 +575,6 @@ class InputPreprocessor:
         Arguments:

         * prompt: an input prompt
-        * request_id

         Returns:

@@ -598,16 +585,11 @@ class InputPreprocessor:
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_inputs = self._prompt_to_llm_inputs(
-                prompt["encoder_prompt"],
-                request_id=request_id,
-            )
+                prompt["encoder_prompt"])
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
-                decoder_inputs = self._prompt_to_llm_inputs(
-                    decoder_input,
-                    request_id=request_id,
-                )
+                decoder_inputs = self._prompt_to_llm_inputs(decoder_input)

             # For multimodal model, override decoder prompt from processor
             # with explicit decoder prompt.
             if self.model_config.is_multimodal_model and (
@@ -616,10 +598,7 @@ class InputPreprocessor:
                 self._separate_enc_dec_inputs_from_mm_processor_outputs(
                     encoder_inputs, decoder_inputs))
         else:
-            inputs = self._prompt_to_llm_inputs(
-                prompt,
-                request_id=request_id,
-            )
+            inputs = self._prompt_to_llm_inputs(prompt)
             if self.model_config.is_multimodal_model and (
                     self._can_process_multimodal()):
                 # Encoder-Decoder Multimodal model
@@ -636,7 +615,6 @@ class InputPreprocessor:
     async def _process_encoder_decoder_prompt_async(
         self,
         prompt: PromptType,
-        request_id: str,
     ) -> EncoderDecoderInputs:
         """Async version of :meth:`_process_encoder_decoder_prompt`."""
         encoder_inputs: SingletonInputs
@@ -644,18 +622,13 @@ class InputPreprocessor:
         if is_explicit_encoder_decoder_prompt(prompt):
             encoder_task = self._prompt_to_llm_inputs_async(
-                prompt["encoder_prompt"],
-                request_id=request_id,
-            )
+                prompt["encoder_prompt"])

             if (decoder_input := prompt["decoder_prompt"]) is None:
                 encoder_inputs = await encoder_task
                 decoder_inputs = None
             else:
-                decoder_task = self._prompt_to_llm_inputs_async(
-                    decoder_input,
-                    request_id=request_id,
-                )
+                decoder_task = self._prompt_to_llm_inputs_async(decoder_input)

                 encoder_inputs, decoder_inputs = await asyncio.gather(
                     encoder_task, decoder_task)

@@ -668,10 +641,7 @@ class InputPreprocessor:
                 self._separate_enc_dec_inputs_from_mm_processor_outputs(
                     encoder_inputs, decoder_inputs))
         else:
-            inputs = await self._prompt_to_llm_inputs_async(
-                prompt,
-                request_id=request_id,
-            )
+            inputs = await self._prompt_to_llm_inputs_async(prompt)
             if self.model_config.is_multimodal_model and (
                     self._can_process_multimodal()):
                 # Encoder-Decoder Multimodal model
@@ -704,7 +674,6 @@ class InputPreprocessor:
     def _process_decoder_only_prompt(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -716,7 +685,6 @@ class InputPreprocessor:
         Arguments:

         * prompt: input prompt
-        * request_id
         * lora_request
         * prompt_adapter_request
         * return_mm_hashes
@@ -728,7 +696,6 @@ class InputPreprocessor:
         prompt_comps = self._prompt_to_llm_inputs(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )

@@ -741,7 +708,6 @@ class InputPreprocessor:
     async def _process_decoder_only_prompt_async(
         self,
         prompt: SingletonPrompt,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -749,7 +715,6 @@ class InputPreprocessor:
         """Async version of :meth:`_process_decoder_only_prompt`."""
         prompt_comps = await self._prompt_to_llm_inputs_async(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             return_mm_hashes=return_mm_hashes,
         )
@@ -762,7 +727,6 @@ class InputPreprocessor:
     def preprocess(
         self,
         prompt: PromptType,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -774,10 +738,7 @@ class InputPreprocessor:
                 "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
-            return self._process_encoder_decoder_prompt(
-                prompt,
-                request_id=request_id,
-            )
+            return self._process_encoder_decoder_prompt(prompt)

         if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
@@ -786,7 +747,6 @@ class InputPreprocessor:
         # Decoder-only operation
         return self._process_decoder_only_prompt(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,
@@ -795,7 +755,6 @@ class InputPreprocessor:
     async def preprocess_async(
         self,
         prompt: PromptType,
-        request_id: str,
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         return_mm_hashes: bool = False,
@@ -807,10 +766,7 @@ class InputPreprocessor:
                 "returned until they are supported on vLLM V1.")
             # Encoder-decoder model requires special mapping of
             # input prompts to encoder & decoder
-            return await self._process_encoder_decoder_prompt_async(
-                prompt,
-                request_id=request_id,
-            )
+            return await self._process_encoder_decoder_prompt_async(prompt)

         if is_explicit_encoder_decoder_prompt(prompt):
             raise ValueError("Cannot pass encoder-decoder prompt "
@@ -819,7 +775,6 @@ class InputPreprocessor:
         # Decoder-only operation
         return await self._process_decoder_only_prompt_async(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=return_mm_hashes,

View File

@@ -33,7 +33,6 @@ class BaseTokenizerGroup(ABC):
     @abstractmethod
     def encode(self,
                prompt: str,
-               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
@@ -43,7 +42,6 @@ class BaseTokenizerGroup(ABC):
     async def encode_async(
             self,
             prompt: str,
-            request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None,
             add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""

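Consolidated from the two hunks above, the abstract interface after this change reads roughly as follows. The class header, decorators, and imports outside the hunks are assumed from context; in particular the LoRARequest import path is an assumption, not something the hunks show.

from abc import ABC, abstractmethod
from typing import List, Optional

from vllm.lora.request import LoRARequest  # import path assumed


class BaseTokenizerGroup(ABC):

    @abstractmethod
    def encode(self,
               prompt: str,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> List[int]:
        """Encode a prompt using the tokenizer group."""

    @abstractmethod
    async def encode_async(
            self,
            prompt: str,
            lora_request: Optional[LoRARequest] = None,
            add_special_tokens: Optional[bool] = None) -> List[int]:
        """Encode a prompt using the tokenizer group."""

Call sites simply drop the request_id keyword, e.g. tokenizer_group.encode(prompt="prompt", lora_request=None), as the updated tests above show.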
View File

@@ -113,7 +113,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):

     def encode(self,
                prompt: str,
-               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
@@ -133,8 +132,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         original_actor = actor
         try:
             ret = ray.get(
-                actor.encode.remote(request_id=request_id,
-                                    prompt=prompt,
+                actor.encode.remote(prompt=prompt,
                                     lora_request=lora_request,
                                     add_special_tokens=add_special_tokens))
         except ActorDiedError as e:
@@ -145,8 +143,7 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             actor = self._init_actor()
             try:
                 ret = ray.get(
-                    actor.encode.remote(request_id=request_id,
-                                        prompt=prompt,
+                    actor.encode.remote(prompt=prompt,
                                         lora_request=lora_request,
                                         add_special_tokens=add_special_tokens))
             except ActorDiedError as e:
@@ -164,7 +161,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
     async def encode_async(
             self,
             prompt: str,
-            request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None,
             add_special_tokens: Optional[bool] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
@@ -184,7 +180,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
         original_actor = actor
         try:
             ret = await actor.encode.remote(
-                request_id=request_id,
                 prompt=prompt,
                 lora_request=lora_request,
                 add_special_tokens=add_special_tokens)
@@ -196,7 +191,6 @@ class RayTokenizerGroupPool(BaseTokenizerGroup):
             actor = self._init_actor()
             try:
                 ret = await actor.encode.remote(
-                    request_id=request_id,
                     prompt=prompt,
                     lora_request=lora_request,
                     add_special_tokens=add_special_tokens)

View File

@@ -56,7 +56,6 @@ class TokenizerGroup(BaseTokenizerGroup):

     def encode(self,
                prompt: str,
-               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None,
               add_special_tokens: Optional[bool] = None) -> List[int]:
         tokenizer = self.get_lora_tokenizer(lora_request)
@@ -69,7 +68,6 @@ class TokenizerGroup(BaseTokenizerGroup):
     async def encode_async(
             self,
             prompt: str,
-            request_id: Optional[str] = None,
             lora_request: Optional[LoRARequest] = None,
             add_special_tokens: Optional[bool] = None) -> List[int]:
         tokenizer = await self.get_lora_tokenizer_async(lora_request)

View File

@@ -4,6 +4,7 @@ import asyncio
 import logging
 import os
 from collections.abc import AsyncGenerator, Mapping
+from copy import copy
 from typing import Optional, Union

 import numpy as np
@@ -25,6 +26,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Device, cdiv, kill_process_tree
+from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.core_client import EngineCoreClient
 from vllm.v1.engine.output_processor import OutputProcessor
 from vllm.v1.engine.parallel_sampling import ParentRequest
@@ -177,34 +179,45 @@ class AsyncLLM(EngineClient):
     ) -> asyncio.Queue[RequestOutput]:
         """Add new request to the AsyncLLM."""

-        # 1) Create a new output queue for the request.
+        # Create a new output queue for the request.
         queue: asyncio.Queue[RequestOutput] = asyncio.Queue()

-        # 2) Fan out child requests (for n>1)
-        parent_req = ParentRequest.from_params(request_id, params)
+        # Convert Input --> Request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         n = params.n if isinstance(params, SamplingParams) else 1
+
+        if n == 1:
+            await self._add_request(request, None, 0, queue)
+            return queue
+
+        # Fan out child requests (for n>1).
+        parent_request = ParentRequest(request_id, params)
         for idx in range(n):
-            if parent_req is not None:
-                request_id, params = parent_req.get_child_info(idx)
-
-            # 3) Convert Input --> Request.
-            request = self.processor.process_inputs(request_id, prompt, params,
-                                                    arrival_time, lora_request,
-                                                    trace_headers,
-                                                    prompt_adapter_request,
-                                                    priority)
-
-            # 4) Add the request to OutputProcessor (this process).
-            self.output_processor.add_request(request, parent_req, idx, queue)
-
-            # 5) Add the EngineCoreRequest to EngineCore (separate process).
-            await self.engine_core.add_request_async(request)
-
-            if self.log_requests:
-                logger.info("Added request %s.", request_id)
-
+            request_id, params = parent_request.get_child_info(idx)
+            child_request = request if idx == n - 1 else copy(request)
+            child_request.request_id = request_id
+            child_request.sampling_params = params
+            await self._add_request(child_request, parent_request, idx, queue)
         return queue

+    async def _add_request(self, request: EngineCoreRequest,
+                           parent_req: Optional[ParentRequest], index: int,
+                           queue: asyncio.Queue[RequestOutput]):
+
+        # Add the request to OutputProcessor (this process).
+        self.output_processor.add_request(request, parent_req, index, queue)
+
+        # Add the EngineCoreRequest to EngineCore (separate process).
+        await self.engine_core.add_request_async(request)
+
+        if self.log_requests:
+            logger.info("Added request %s.", request.request_id)
+
     # TODO: we should support multiple prompts in one call, as you
     # can do with LLM.generate. So that for multi-prompt completion
     # requests we don't need to send multiple messages to core proc,
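One detail worth noting in the hunk above: copy() is a shallow copy, so the expensive processed fields (prompt token IDs, multimodal inputs, and so on) are shared between children rather than re-materialized; only per-child fields such as request_id and sampling_params are overwritten. A small self-contained illustration of that semantics, with names that are illustrative rather than vLLM's EngineCoreRequest:

from copy import copy
from dataclasses import dataclass


@dataclass
class Req:  # illustrative stand-in
    request_id: str
    prompt_token_ids: list[int]


parent = Req("req-0", [1, 2, 3])
child = copy(parent)          # shallow copy
child.request_id = "req-0_0"  # per-child field overwritten

assert child.prompt_token_ids is parent.prompt_token_ids  # token list shared
assert parent.request_id == "req-0"  # parent's id is untouched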

View File

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 from collections.abc import Mapping
+from copy import copy
 from typing import Optional, Union

 from typing_extensions import TypeVar
@@ -179,25 +180,34 @@ class LLMEngine:
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         priority: int = 0,
     ) -> None:
-        # 1) Fan out child requests (for n>1)
-        parent_req = ParentRequest.from_params(request_id, params)
+        # Process raw inputs into the request.
+        request = self.processor.process_inputs(request_id, prompt, params,
+                                                arrival_time, lora_request,
+                                                trace_headers,
+                                                prompt_adapter_request,
+                                                priority)

         n = params.n if isinstance(params, SamplingParams) else 1
-        for idx in range(n):
-            if parent_req is not None:
-                request_id, params = parent_req.get_child_info(idx)
-
-            # 2) Process raw inputs into the request.
-            request = self.processor.process_inputs(request_id, prompt, params,
-                                                    arrival_time, lora_request,
-                                                    trace_headers,
-                                                    prompt_adapter_request,
-                                                    priority)
+
+        if n == 1:
+            # Make a new RequestState and queue.
+            self.output_processor.add_request(request, None, 0)

-            # 3) Make a new RequestState and queue.
-            self.output_processor.add_request(request, parent_req, idx)
-
-            # 3) Add the request to EngineCore.
+            # Add the request to EngineCore.
             self.engine_core.add_request(request)
+            return
+
+        # Fan out child requests (for n>1).
+        parent_req = ParentRequest(request_id, params)
+        for idx in range(n):
+            request_id, params = parent_req.get_child_info(idx)
+            child_request = request if idx == n - 1 else copy(request)
+            child_request.request_id = request_id
+            child_request.sampling_params = params
+            # Make a new RequestState and queue.
+            self.output_processor.add_request(child_request, parent_req, idx)
+
+            # Add the request to EngineCore.
+            self.engine_core.add_request(child_request)

     def step(self) -> list[RequestOutput]:

View File

@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0

 from copy import copy
-from typing import Optional, Union
+from typing import Optional

 from vllm.outputs import CompletionOutput
-from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import RequestOutputKind, SamplingParams
 from vllm.v1.metrics.stats import IterationStats

@@ -43,16 +42,6 @@ class ParentRequest:
         self.max_num_generation_tokens = 0
         self.cached_child_sampling_params = None

-    @classmethod
-    def from_params(
-        cls,
-        request_id: str,
-        params: Union[SamplingParams, PoolingParams],
-    ) -> Optional['ParentRequest']:
-        if not isinstance(params, SamplingParams) or params.n == 1:
-            return None
-        return cls(request_id, params)
-
     def _get_child_sampling_params(
         self,
         index: int,

View File

@@ -173,7 +173,6 @@ class Processor:
         # 3. Apply prompt adapter to prompt token ids if one exists.
         processed_inputs: ProcessorInputs = self.input_preprocessor.preprocess(
             prompt,
-            request_id=request_id,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             return_mm_hashes=self.use_hash,
         )