Reinstate best_of for V0 (#14356)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-03-06 17:34:22 +01:00 committed by GitHub
parent 151b08e0fe
commit bf0560bda9
6 changed files with 50 additions and 3 deletions

View File

@@ -25,6 +25,14 @@ def test_n_gt_1(model):
assert len(outputs[0].outputs) == 3
def test_best_of(model):
"""Raise a ValueError since best_of is deprecated."""
params = SamplingParams(n=2, best_of=3)
with pytest.raises(ValueError):
_ = model.generate(PROMPT, params)
def test_penalties(model):
"""Check that we do not get errors if applied."""

View File

@@ -97,7 +97,11 @@ class LLM:
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
Too small values may cause out-of-memory (OOM) errors.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests have `best_of=1`, you can safely set this to 0. Note that
`best_of` is only supported in V0. Otherwise, too small values may
cause out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
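As a quick illustration of the swap_space guidance above, here is a minimal sketch of a V0 setup that reserves CPU swap for requests using best_of; the model name, swap size, and prompt are placeholders, not part of this commit:

from vllm import LLM, SamplingParams

# Reserve 4 GiB of CPU swap per GPU so request states can be swapped out
# when best_of > 1 (V0 only); with best_of=1 everywhere, swap_space=0 is safe.
llm = LLM(model="facebook/opt-125m", swap_space=4)

params = SamplingParams(n=2, best_of=4, temperature=0.8)
outputs = llm.generate(["Hello, my name is"], params)
print(len(outputs[0].outputs))  # the top n=2 of the 4 generated sequences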

View File

@@ -242,6 +242,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: Optional[int] = None
min_p: Optional[float] = None
@@ -478,6 +479,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
@@ -648,6 +650,7 @@ class CompletionRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = None
prompt: Union[list[int], list[list[int]], str, list[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[dict[str, float]] = None
@@ -845,6 +848,7 @@ class CompletionRequest(OpenAIBaseModel):
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,

View File

@@ -168,8 +168,12 @@ class OpenAIServingCompletion(OpenAIServing):
model_name = self._get_model_name(request.model, lora_request)
num_prompts = len(engine_prompts)
# We do not stream the results when use beam search.
stream = (request.stream and not request.use_beam_search)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. Note that best_of is only supported in V0. In addition,
# we do not stream the results when using beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
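The streaming rule above can be read as a single predicate; the sketch below restates it with a hypothetical Req stand-in for the actual request model, purely for illustration:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Req:
    # Hypothetical stand-in for the request object, not vLLM's actual class.
    stream: bool
    n: int
    best_of: Optional[int] = None
    use_beam_search: bool = False

def should_stream(r: Req) -> bool:
    # Stream only if the client asked for it, best_of (V0 only) equals n
    # when given, and beam search is not in use.
    return (r.stream
            and (r.best_of is None or r.n == r.best_of)
            and not r.use_beam_search)

assert should_stream(Req(stream=True, n=2, best_of=2))
assert not should_stream(Req(stream=True, n=2, best_of=4))
assert not should_stream(Req(stream=True, n=1, use_beam_search=True))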

View File

@@ -116,6 +116,10 @@ class SamplingParams(
Args:
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. By default,
`best_of` is set to `n`. Warning: this is only supported in V0.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
@@ -183,6 +187,7 @@ class SamplingParams(
"""
n: int = 1
best_of: Optional[int] = None
_real_n: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
@@ -226,6 +231,7 @@ class SamplingParams(
@staticmethod
def from_optional(
n: Optional[int] = 1,
best_of: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
@@ -264,6 +270,7 @@ class SamplingParams(
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=0.0
if presence_penalty is None else presence_penalty,
frequency_penalty=0.0
@@ -296,6 +303,20 @@ class SamplingParams(
)
def __post_init__(self) -> None:
# How we deal with `best_of`:
# if `best_of` is not set, we default to `n`;
# if `best_of` is set, we set `n` to `best_of`,
# and set `_real_n` to the original `n`.
# When we return the result, we will check
# whether we need to return `n` or `_real_n` results.
if self.best_of:
if self.best_of < self.n:
raise ValueError(
f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
if not self._real_n:
self._real_n = self.n
self.n = self.best_of
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
@@ -402,6 +423,9 @@ class SamplingParams(
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self._real_n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_greedy_sampling(self) -> None:
if self.n > 1:
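The net effect of the __post_init__ handling above is that `n` is widened to `best_of` internally while `_real_n` remembers the requested `n`; a small sketch, assuming the V0 engine:

from vllm import SamplingParams

params = SamplingParams(n=2, best_of=3)
# __post_init__ widens n to best_of and records the original n in _real_n,
# so 3 sequences are sampled but only the top 2 are returned.
assert params.n == 3
assert params._real_n == 2

# best_of smaller than n is rejected up front.
try:
    SamplingParams(n=3, best_of=2)
except ValueError as err:
    print(err)  # best_of must be greater than or equal to n, ...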

View File

@@ -93,6 +93,9 @@ class Processor:
self,
params: SamplingParams,
) -> None:
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("VLLM V1 does not yet support best_of.")
# Bad words not yet supported.
if params.bad_words:
raise ValueError("VLLM V1 does not yet support bad_words.")