[Frontend] API support for beam search (#9087)
Co-authored-by: youkaichao <youkaichao@126.com>
Parent: 23fea8714a
Commit: 168cab6bbf
@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
@@ -145,10 +146,13 @@ def run_vllm(
         for prompt, input_len, _output_len in requests:
             assert _output_len == output_len
         start = time.perf_counter()
-        llm.beam_search(prompts,
-                        beam_width=n,
-                        max_tokens=output_len,
-                        ignore_eos=True)
+        llm.beam_search(
+            prompts,
+            BeamSearchParams(
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True,
+            ))
         end = time.perf_counter()
     return end - start
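
For reference, a minimal offline usage sketch of the updated `LLM.beam_search` call, which now takes a single `BeamSearchParams` object instead of separate keyword arguments. The model name and prompt below are placeholders, not part of this diff.

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

# Placeholder model; any small HF causal LM works the same way.
llm = LLM(model="facebook/opt-125m")

params = BeamSearchParams(beam_width=4, max_tokens=32, ignore_eos=False)

# Prompts plus one params object, matching the new signature used above.
outputs = llm.beam_search(["The capital of France is"], params)
for output in outputs:
    # Each output carries the kept beams; see the conftest change below,
    # which reads output.sequences[i].tokens the same way.
    print(output.sequences[0].tokens)
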
@@ -35,6 +35,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
 
@@ -812,7 +813,9 @@ class VllmRunner:
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
-        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        outputs = self.model.beam_search(
+            prompts,
+            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:
             token_ids = [x.tokens for x in output.sequences]
@@ -495,25 +495,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     assert len(batch.choices) == 2
     assert batch.choices[0].text == batch.choices[1].text
 
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=prompts,
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
+    try:
+        # test n = 2
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            n=2,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body=dict(
+                # NOTE: this has to be true for n > 1 in vLLM, but
+                # not necessary for official client.
+                use_beam_search=True),
+        )
+        assert len(batch.choices) == 4
+        assert batch.choices[0].text != batch.choices[
+            1].text, "beam search should be different"
+        assert batch.choices[0].text == batch.choices[
+            2].text, "two copies of the same prompt should be the same"
+        assert batch.choices[1].text == batch.choices[
+            3].text, "two copies of the same prompt should be the same"
+    except BadRequestError as e:
+        # the only allowed exception is when beam search is not supported
+        # in the default mqllmengine
+        assert "--disable-frontend-multiprocessing" in str(e)
 
     # test streaming
     batch = await client.completions.create(
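
A sketch of the client-side flow this test exercises against a running vLLM OpenAI-compatible server. The base URL and model name are placeholders; per the except branch above, the server must be started with `--disable-frontend-multiprocessing` (so the frontend talks to an AsyncLLMEngine directly) for beam-search requests to be accepted.

import openai

# Placeholder endpoint/key for a locally running vLLM server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="facebook/opt-125m",  # placeholder model name
    prompt="Hello, my name is",
    n=2,                        # mapped to beam_width by to_beam_search_params
    max_tokens=5,
    temperature=0.0,
    # vLLM-specific field; routes the request through to_beam_search_params().
    extra_body=dict(use_beam_search=True),
)
for choice in completion.choices:
    print(choice.index, choice.text)
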
@@ -14,23 +14,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.entrypoints.llm import BeamSearchSequence
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid, weak_bind)
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1036,6 +1039,102 @@ class AsyncLLMEngine:
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+
+        tokenizer = await self.get_tokenizer()
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
+                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+
+            output = [x[0] for x in output]
+
+            logger.info(output)
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
+
     async def encode(
         self,
         prompt: PromptType,
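
A rough sketch of how a caller could drive the new `AsyncLLMEngine.beam_search` generator directly. The engine arguments, model name, prompt, and request id below are placeholders; the generator yields a single finished RequestOutput holding the top beam_width beams, as implemented above.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import BeamSearchParams


async def main() -> None:
    # Placeholder engine configuration.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    params = BeamSearchParams(beam_width=4, max_tokens=16)
    async for output in engine.beam_search("The future of AI is",
                                           "beam-request-0", params):
        # Each CompletionOutput corresponds to one surviving beam.
        for completion in output.outputs:
            print(completion.cumulative_logprob, completion.text)


asyncio.run(main())
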
@@ -22,8 +22,8 @@ from vllm.model_executor.guided_decoding.guided_fields import (
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -394,10 +394,7 @@ class LLM:
     def beam_search(
         self,
         prompts: List[Union[str, List[int]]],
-        beam_width: int,
-        max_tokens: int,
-        ignore_eos: bool = False,
-        temperature: float = 0.0,
+        params: BeamSearchParams,
     ) -> List[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -405,14 +402,17 @@ class LLM:
         Args:
             prompts: A list of prompts. Each prompt can be a string or a list
                 of token IDs.
-            beam_width: The number of beams to keep at each step.
-            max_tokens: The max number of tokens to generate for each prompt.
-            temperature: The temperature to use for generation.
+            params: The beam search parameters.
 
         TODO: how does beam search work together with length penalty, frequency
         penalty, and stopping criteria, etc.?
         """
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        temperature = params.temperature
+        ignore_eos = params.ignore_eos
+
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
@@ -4,7 +4,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 
 logger = init_logger(__name__)
 
@@ -21,7 +21,8 @@ class RequestLogger:
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> None:
@@ -11,8 +11,8 @@ from typing_extensions import Annotated, Required, TypedDict
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
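
To illustrate the mapping that `to_beam_search_params` performs on both ChatCompletionRequest (above) and CompletionRequest (below), a small standalone sketch; the helper name and the values here are made up for illustration only.

from vllm.sampling_params import BeamSearchParams


def to_beam_search_params_sketch(n, max_tokens, temperature, ignore_eos,
                                 default_max_tokens):
    # The OpenAI `n` field becomes the beam width, a missing max_tokens falls
    # back to the server-side default, and temperature defaults to 0.0.
    return BeamSearchParams(
        beam_width=n if n is not None else 1,
        max_tokens=max_tokens if max_tokens is not None else default_max_tokens,
        ignore_eos=ignore_eos,
        temperature=temperature if temperature is not None else 0.0,
    )


params = to_beam_search_params_sketch(n=4, max_tokens=None, temperature=None,
                                      ignore_eos=False, default_max_tokens=128)
print(params)  # BeamSearchParams(beam_width=4, max_tokens=128, ...)
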
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
@@ -9,6 +9,7 @@ from typing import Union
 from fastapi import Request
 
 from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
@@ -203,9 +205,15 @@ class OpenAIServingChat(OpenAIServing):
 
             assert prompt_inputs is not None
 
-            sampling_params = request.to_sampling_params(
-                default_max_tokens=self.max_model_len -
-                len(prompt_inputs["prompt_token_ids"]))
+            sampling_params: Union[SamplingParams, BeamSearchParams]
+            default_max_tokens = self.max_model_len - len(
+                prompt_inputs["prompt_token_ids"])
+            if request.use_beam_search:
+                sampling_params = request.to_beam_search_params(
+                    default_max_tokens)
+            else:
+                sampling_params = request.to_sampling_params(
+                    default_max_tokens)
 
             self._log_inputs(request_id,
                              prompt_inputs,
@@ -227,15 +235,26 @@ class OpenAIServingChat(OpenAIServing):
                     and contains_trace_headers(raw_request.headers)):
                 log_tracing_disabled_warning()
 
-            result_generator = self.engine_client.generate(
-                engine_inputs,
-                sampling_params,
-                request_id,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                prompt_adapter_request=prompt_adapter_request,
-                priority=request.priority,
-            )
+            if isinstance(sampling_params, BeamSearchParams):
+                if not isinstance(self.engine_client, AsyncLLMEngine):
+                    raise ValueError(
+                        "Beam search in the API server is only supported with"
+                        " AsyncLLMEngine. please add "
+                        "`--disable-frontend-multiprocessing` to "
+                        "use beam search.")
+                result_generator = self.engine_client.beam_search(
+                    engine_inputs['prompt_token_ids'], request_id,
+                    sampling_params)
+            else:
+                result_generator = self.engine_client.generate(
+                    engine_inputs,
+                    sampling_params,
+                    request_id,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=request.priority,
+                )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -8,6 +8,7 @@ from typing import Tuple, Union, cast
 from fastapi import Request
 
 from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -28,6 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     PromptAdapterPath)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
@@ -120,9 +122,15 @@ class OpenAIServingCompletion(OpenAIServing):
                 ))
 
             for i, prompt_inputs in enumerate(prompts):
-                sampling_params = request.to_sampling_params(
-                    default_max_tokens=self.max_model_len -
-                    len(prompt_inputs["prompt_token_ids"]))
+                sampling_params: Union[SamplingParams, BeamSearchParams]
+                default_max_tokens = self.max_model_len - len(
+                    prompt_inputs["prompt_token_ids"])
+                if request.use_beam_search:
+                    sampling_params = request.to_beam_search_params(
+                        default_max_tokens)
+                else:
+                    sampling_params = request.to_sampling_params(
+                        default_max_tokens)
 
                 request_id_item = f"{request_id}-{i}"
 
@@ -141,15 +149,29 @@ class OpenAIServingCompletion(OpenAIServing):
                                        raw_request.headers):
                     log_tracing_disabled_warning()
 
-                generator = self.engine_client.generate(
-                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
-                    sampling_params,
-                    request_id_item,
-                    lora_request=lora_request,
-                    prompt_adapter_request=prompt_adapter_request,
-                    trace_headers=trace_headers,
-                    priority=request.priority,
-                )
+                if isinstance(sampling_params, BeamSearchParams):
+                    if not isinstance(self.engine_client, AsyncLLMEngine):
+                        raise ValueError(
+                            "Beam search in the API server is only supported"
+                            " with AsyncLLMEngine. please add "
+                            "`--disable-frontend-multiprocessing` to "
+                            "use beam search.")
+                    generator = self.engine_client.beam_search(
+                        prompt_inputs["prompt_token_ids"], request_id_item,
+                        sampling_params)
+                else:
+                    generator = self.engine_client.generate(
+                        {
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        sampling_params,
+                        request_id_item,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request,
+                        trace_headers=trace_headers,
+                        priority=request.priority,
+                    )
 
                 generators.append(generator)
             except ValueError as e:
@@ -29,7 +29,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AtomicCounter
@@ -371,7 +371,8 @@ class OpenAIServing:
         self,
         request_id: str,
         inputs: Union[str, List[int], TextTokensPrompt],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
         lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> None:
@@ -530,3 +530,15 @@ class SamplingParams(
             f"{self.spaces_between_special_tokens}, "
             f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
             f"guided_decoding={self.guided_decoding}")
+
+
+class BeamSearchParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        # required for @cached_property.
+        dict=True):  # type: ignore[call-arg]
+    """Beam search parameters for text generation."""
+    beam_width: int
+    max_tokens: int
+    ignore_eos: bool = False
+    temperature: float = 0.0
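
Because BeamSearchParams is a msgspec.Struct like the other parameter classes in this module, it can be constructed and serialized with no extra glue. A short sketch, assuming msgspec is available (the SamplingParams class above already builds on it):

import msgspec

from vllm.sampling_params import BeamSearchParams

params = BeamSearchParams(beam_width=4, max_tokens=64)
print(params.ignore_eos, params.temperature)  # defaults: False 0.0

# msgspec structs round-trip through JSON/msgpack out of the box.
encoded = msgspec.json.encode(params)
decoded = msgspec.json.decode(encoded, type=BeamSearchParams)
assert decoded == params
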
@@ -504,6 +504,15 @@ async def merge_async_iterators(
             await it.aclose()
 
 
+async def collect_from_async_generator(
+        iterator: AsyncGenerator[T, None]) -> List[T]:
+    """Collect all items from an async generator into a list."""
+    items = []
+    async for item in iterator:
+        items.append(item)
+    return items
+
+
 def get_ip() -> str:
     host_ip = envs.VLLM_HOST_IP
     if host_ip:
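
A minimal usage sketch of the new helper; the toy generator below stands in for an async stream such as the one returned by AsyncLLMEngine.generate().

import asyncio
from typing import AsyncGenerator

from vllm.utils import collect_from_async_generator


async def numbers() -> AsyncGenerator[int, None]:
    # Toy async generator used only for illustration.
    for i in range(3):
        yield i


async def main() -> None:
    items = await collect_from_async_generator(numbers())
    print(items)  # [0, 1, 2]


asyncio.run(main())
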