[Frontend] API support for beam search (#9087)
Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
parent 23fea8714a
commit 168cab6bbf
@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
 
 
@@ -145,10 +146,13 @@ def run_vllm(
         for prompt, input_len, _output_len in requests:
             assert _output_len == output_len
         start = time.perf_counter()
-        llm.beam_search(prompts,
-                        beam_width=n,
-                        max_tokens=output_len,
-                        ignore_eos=True)
+        llm.beam_search(
+            prompts,
+            BeamSearchParams(
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True,
+            ))
         end = time.perf_counter()
     return end - start
 
@@ -35,6 +35,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
 
@@ -812,7 +813,9 @@ class VllmRunner:
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
-        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        outputs = self.model.beam_search(
+            prompts,
+            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:
             token_ids = [x.tokens for x in output.sequences]
@@ -495,25 +495,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     assert len(batch.choices) == 2
     assert batch.choices[0].text == batch.choices[1].text
 
-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=prompts,
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
+    try:
+        # test n = 2
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            n=2,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body=dict(
+                # NOTE: this has to be true for n > 1 in vLLM, but
+                # not necessary for official client.
+                use_beam_search=True),
+        )
+        assert len(batch.choices) == 4
+        assert batch.choices[0].text != batch.choices[
+            1].text, "beam search should be different"
+        assert batch.choices[0].text == batch.choices[
+            2].text, "two copies of the same prompt should be the same"
+        assert batch.choices[1].text == batch.choices[
+            3].text, "two copies of the same prompt should be the same"
+    except BadRequestError as e:
+        # the only allowed exception is when beam search is not supported
+        # in the default mqllmengine
+        assert "--disable-frontend-multiprocessing" in str(e)
 
     # test streaming
     batch = await client.completions.create(
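For reference, a minimal client-side sketch of the beam-search request this test exercises; the base URL, API key, and model name below are placeholders rather than values taken from this change:

# Hedged usage sketch; base_url, api_key, and model name are placeholders.
import asyncio

import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                            api_key="EMPTY")

async def demo() -> None:
    batch = await client.completions.create(
        model="facebook/opt-125m",  # placeholder model name
        prompt=["Hello, my name is", "Hello, my name is"],
        n=2,
        max_tokens=5,
        temperature=0.0,
        # vLLM-specific extra parameter; per the test above, the server
        # rejects it with BadRequestError unless it was started with
        # --disable-frontend-multiprocessing
        extra_body=dict(use_beam_search=True),
    )
    print([choice.text for choice in batch.choices])

asyncio.run(demo())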
@@ -14,23 +14,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.entrypoints.llm import BeamSearchSequence
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid, weak_bind)
 
 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1036,6 +1039,102 @@ class AsyncLLMEngine:
         ):
             yield LLMEngine.validate_output(output, RequestOutput)
 
+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+
+        tokenizer = await self.get_tokenizer()
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
+                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+
+            output = [x[0] for x in output]
+
+            logger.info(output)
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
+
     async def encode(
         self,
         prompt: PromptType,
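The loop above asks the engine for the top 2 * beam_width candidate tokens per live beam (via logprobs), extends each beam with every candidate, and keeps only the beam_width sequences with the highest cumulative log-probability. A standalone sketch of that expand-and-prune step, detached from the engine; the Beam class and the candidate log-probabilities are invented for illustration:

# Standalone illustration of the expand-and-prune step used above; the
# candidate logprobs here are invented, not produced by the engine.
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Beam:
    tokens: List[int]
    cum_logprob: float

def prune(beams: List[Beam], candidates: List[Dict[int, float]],
          beam_width: int) -> List[Beam]:
    """Extend each beam with its candidate tokens, keep the best ones."""
    expanded = [
        Beam(tokens=b.tokens + [tok], cum_logprob=b.cum_logprob + lp)
        for b, cands in zip(beams, candidates)
        for tok, lp in cands.items()
    ]
    return sorted(expanded, key=lambda b: b.cum_logprob,
                  reverse=True)[:beam_width]

beams = [Beam(tokens=[1], cum_logprob=0.0)]
# one decoding step: token 7 is more likely than token 9
beams = prune(beams, [{7: -0.1, 9: -2.3}], beam_width=2)
assert [b.tokens for b in beams] == [[1, 7], [1, 9]]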
@@ -22,8 +22,8 @@ from vllm.model_executor.guided_decoding.guided_fields import (
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -394,10 +394,7 @@ class LLM:
     def beam_search(
         self,
         prompts: List[Union[str, List[int]]],
-        beam_width: int,
-        max_tokens: int,
-        ignore_eos: bool = False,
-        temperature: float = 0.0,
+        params: BeamSearchParams,
     ) -> List[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -405,14 +402,17 @@
         Args:
             prompts: A list of prompts. Each prompt can be a string or a list
                 of token IDs.
-            beam_width: The number of beams to keep at each step.
-            max_tokens: The max number of tokens to generate for each prompt.
-            temperature: The temperature to use for generation.
-
+            params: The beam search parameters.
+
         TODO: how does beam search work together with length penalty, frequency
         penalty, and stopping criteria, etc.?
         """
 
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        temperature = params.temperature
+        ignore_eos = params.ignore_eos
+
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
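With this signature, offline beam search is driven entirely by the params object. A minimal usage sketch; the model name and prompt are placeholders, and the .sequences/.tokens access mirrors the test helper change earlier in this diff:

# Minimal sketch of the updated offline API; model name and prompt are
# placeholders, not taken from this diff.
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.beam_search(
    ["The capital of France is"],
    BeamSearchParams(beam_width=4, max_tokens=32),
)
for output in outputs:
    # each output holds the surviving beams; .sequences[i].tokens is the
    # same access pattern used by the test helper above
    print([seq.tokens for seq in output.sequences])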
@@ -4,7 +4,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 
 logger = init_logger(__name__)
 
@@ -21,7 +21,8 @@ class RequestLogger:
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> None:
@@ -11,8 +11,8 @@ from typing_extensions import Annotated, Required, TypedDict
 
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):
 
     # doc: end-completion-extra-params
 
+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
        if max_tokens is None:
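Both request types map the OpenAI-style fields onto BeamSearchParams the same way: n becomes beam_width, max_tokens falls back to the server-computed default, and an unset temperature becomes 0.0. An illustrative example with made-up values:

# Illustrative only: the values below are invented to show the mapping
# performed by to_beam_search_params.
from vllm.sampling_params import BeamSearchParams

# a request with n=2, max_tokens=None, temperature=None, ignore_eos=False,
# and default_max_tokens=100 supplied by the server would map to:
params = BeamSearchParams(
    beam_width=2,      # request.n
    max_tokens=100,    # default_max_tokens, since request.max_tokens is None
    ignore_eos=False,  # request.ignore_eos
    temperature=0.0,   # 0.0 when request.temperature is unset
)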
@@ -9,6 +9,7 @@ from typing import Union
 from fastapi import Request
 
 from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
@@ -203,9 +205,15 @@ class OpenAIServingChat(OpenAIServing):
 
             assert prompt_inputs is not None
 
-            sampling_params = request.to_sampling_params(
-                default_max_tokens=self.max_model_len -
-                len(prompt_inputs["prompt_token_ids"]))
+            sampling_params: Union[SamplingParams, BeamSearchParams]
+            default_max_tokens = self.max_model_len - len(
+                prompt_inputs["prompt_token_ids"])
+            if request.use_beam_search:
+                sampling_params = request.to_beam_search_params(
+                    default_max_tokens)
+            else:
+                sampling_params = request.to_sampling_params(
+                    default_max_tokens)
 
             self._log_inputs(request_id,
                              prompt_inputs,
@@ -227,15 +235,26 @@
                     and contains_trace_headers(raw_request.headers)):
                 log_tracing_disabled_warning()
 
-            result_generator = self.engine_client.generate(
-                engine_inputs,
-                sampling_params,
-                request_id,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                prompt_adapter_request=prompt_adapter_request,
-                priority=request.priority,
-            )
+            if isinstance(sampling_params, BeamSearchParams):
+                if not isinstance(self.engine_client, AsyncLLMEngine):
+                    raise ValueError(
+                        "Beam search in the API server is only supported with"
+                        " AsyncLLMEngine. please add "
+                        "`--disable-frontend-multiprocessing` to "
+                        "use beam search.")
+                result_generator = self.engine_client.beam_search(
+                    engine_inputs['prompt_token_ids'], request_id,
+                    sampling_params)
+            else:
+                result_generator = self.engine_client.generate(
+                    engine_inputs,
+                    sampling_params,
+                    request_id,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=request.priority,
+                )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
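On this path the AsyncLLMEngine.beam_search generator yields a single finished RequestOutput whose outputs list carries one CompletionOutput per surviving beam. A hedged sketch of how a caller might consume it; the engine variable is assumed to be an already-constructed AsyncLLMEngine:

# Hedged consumption sketch; `engine` is an assumed, already-constructed
# AsyncLLMEngine instance.
from vllm.sampling_params import BeamSearchParams

async def run_beam_search(engine, prompt: str, request_id: str) -> None:
    params = BeamSearchParams(beam_width=2, max_tokens=16)
    async for request_output in engine.beam_search(prompt, request_id, params):
        # one finished RequestOutput; each CompletionOutput is one beam
        for completion in request_output.outputs:
            print(completion.text, completion.cumulative_logprob)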
@@ -8,6 +8,7 @@ from typing import Tuple, Union, cast
 from fastapi import Request
 
 from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -28,6 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     PromptAdapterPath)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
@@ -120,9 +122,15 @@ class OpenAIServingCompletion(OpenAIServing):
                 ))
 
             for i, prompt_inputs in enumerate(prompts):
-                sampling_params = request.to_sampling_params(
-                    default_max_tokens=self.max_model_len -
-                    len(prompt_inputs["prompt_token_ids"]))
+                sampling_params: Union[SamplingParams, BeamSearchParams]
+                default_max_tokens = self.max_model_len - len(
+                    prompt_inputs["prompt_token_ids"])
+                if request.use_beam_search:
+                    sampling_params = request.to_beam_search_params(
+                        default_max_tokens)
+                else:
+                    sampling_params = request.to_sampling_params(
+                        default_max_tokens)
 
                 request_id_item = f"{request_id}-{i}"
 
@@ -141,15 +149,29 @@
                         raw_request.headers):
                     log_tracing_disabled_warning()
 
-                generator = self.engine_client.generate(
-                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
-                    sampling_params,
-                    request_id_item,
-                    lora_request=lora_request,
-                    prompt_adapter_request=prompt_adapter_request,
-                    trace_headers=trace_headers,
-                    priority=request.priority,
-                )
+                if isinstance(sampling_params, BeamSearchParams):
+                    if not isinstance(self.engine_client, AsyncLLMEngine):
+                        raise ValueError(
+                            "Beam search in the API server is only supported"
+                            " with AsyncLLMEngine. please add "
+                            "`--disable-frontend-multiprocessing` to "
+                            "use beam search.")
+                    generator = self.engine_client.beam_search(
+                        prompt_inputs["prompt_token_ids"], request_id_item,
+                        sampling_params)
+                else:
+                    generator = self.engine_client.generate(
+                        {
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        sampling_params,
+                        request_id_item,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request,
+                        trace_headers=trace_headers,
+                        priority=request.priority,
+                    )
 
                 generators.append(generator)
             except ValueError as e:
@@ -29,7 +29,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AtomicCounter
@@ -371,7 +371,8 @@ class OpenAIServing:
         self,
         request_id: str,
         inputs: Union[str, List[int], TextTokensPrompt],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
         lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> None:
@@ -530,3 +530,15 @@ class SamplingParams(
                 f"{self.spaces_between_special_tokens}, "
                 f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
                 f"guided_decoding={self.guided_decoding}")
+
+
+class BeamSearchParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        # required for @cached_property.
+        dict=True):  # type: ignore[call-arg]
+    """Beam search parameters for text generation."""
+    beam_width: int
+    max_tokens: int
+    ignore_eos: bool = False
+    temperature: float = 0.0
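As defined above, only beam_width and max_tokens are required; ignore_eos and temperature fall back to their defaults. A quick sketch, not part of this diff:

# Quick sketch of constructing the new params object; not part of the diff.
from vllm.sampling_params import BeamSearchParams

params = BeamSearchParams(beam_width=4, max_tokens=16)
assert params.ignore_eos is False
assert params.temperature == 0.0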
@@ -504,6 +504,15 @@ async def merge_async_iterators(
             await it.aclose()
 
 
+async def collect_from_async_generator(
+        iterator: AsyncGenerator[T, None]) -> List[T]:
+    """Collect all items from an async generator into a list."""
+    items = []
+    async for item in iterator:
+        items.append(item)
+    return items
+
+
 def get_ip() -> str:
     host_ip = envs.VLLM_HOST_IP
     if host_ip:
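A small self-contained usage sketch for the new helper; the numbers() generator is invented for the example:

# Usage sketch for the new helper; the numbers() generator is made up.
import asyncio

from vllm.utils import collect_from_async_generator

async def numbers():
    for i in range(3):
        yield i

async def main() -> None:
    items = await collect_from_async_generator(numbers())
    assert items == [0, 1, 2]

asyncio.run(main())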