Reinstate best_of
for V0 (#14356)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
151b08e0fe
commit
bf0560bda9
@ -25,6 +25,14 @@ def test_n_gt_1(model):
|
|||||||
assert len(outputs[0].outputs) == 3
|
assert len(outputs[0].outputs) == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_best_of(model):
|
||||||
|
"""Raise a ValueError since best_of is deprecated."""
|
||||||
|
|
||||||
|
params = SamplingParams(n=2, best_of=3)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_ = model.generate(PROMPT, params)
|
||||||
|
|
||||||
|
|
||||||
def test_penalties(model):
|
def test_penalties(model):
|
||||||
"""Check that we do not get errors if applied."""
|
"""Check that we do not get errors if applied."""
|
||||||
|
|
||||||
|
@ -97,7 +97,11 @@ class LLM:
|
|||||||
throughput. However, if the value is too high, it may cause out-of-
|
throughput. However, if the value is too high, it may cause out-of-
|
||||||
memory (OOM) errors.
|
memory (OOM) errors.
|
||||||
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
|
||||||
Too small values may cause out-of-memory (OOM) errors.
|
This can be used for temporarily storing the states of the requests
|
||||||
|
when their `best_of` sampling parameters are larger than 1. If all
|
||||||
|
requests will have `best_of=1`, you can safely set this to 0.
|
||||||
|
Noting that `best_of` is only supported in V0. Otherwise, too small
|
||||||
|
values may cause out-of-memory (OOM) errors.
|
||||||
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
|
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
|
||||||
the model weights. This virtually increases the GPU memory space
|
the model weights. This virtually increases the GPU memory space
|
||||||
you can use to hold the model weights, at the cost of CPU-GPU data
|
you can use to hold the model weights, at the cost of CPU-GPU data
|
||||||
|
@ -242,6 +242,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
user: Optional[str] = None
|
user: Optional[str] = None
|
||||||
|
|
||||||
# doc: begin-chat-completion-sampling-params
|
# doc: begin-chat-completion-sampling-params
|
||||||
|
best_of: Optional[int] = None
|
||||||
use_beam_search: bool = False
|
use_beam_search: bool = False
|
||||||
top_k: Optional[int] = None
|
top_k: Optional[int] = None
|
||||||
min_p: Optional[float] = None
|
min_p: Optional[float] = None
|
||||||
@ -478,6 +479,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
return SamplingParams.from_optional(
|
return SamplingParams.from_optional(
|
||||||
n=self.n,
|
n=self.n,
|
||||||
|
best_of=self.best_of,
|
||||||
presence_penalty=self.presence_penalty,
|
presence_penalty=self.presence_penalty,
|
||||||
frequency_penalty=self.frequency_penalty,
|
frequency_penalty=self.frequency_penalty,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
@ -648,6 +650,7 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
# https://platform.openai.com/docs/api-reference/completions/create
|
# https://platform.openai.com/docs/api-reference/completions/create
|
||||||
model: Optional[str] = None
|
model: Optional[str] = None
|
||||||
prompt: Union[list[int], list[list[int]], str, list[str]]
|
prompt: Union[list[int], list[list[int]], str, list[str]]
|
||||||
|
best_of: Optional[int] = None
|
||||||
echo: Optional[bool] = False
|
echo: Optional[bool] = False
|
||||||
frequency_penalty: Optional[float] = 0.0
|
frequency_penalty: Optional[float] = 0.0
|
||||||
logit_bias: Optional[dict[str, float]] = None
|
logit_bias: Optional[dict[str, float]] = None
|
||||||
@ -845,6 +848,7 @@ class CompletionRequest(OpenAIBaseModel):
|
|||||||
|
|
||||||
return SamplingParams.from_optional(
|
return SamplingParams.from_optional(
|
||||||
n=self.n,
|
n=self.n,
|
||||||
|
best_of=self.best_of,
|
||||||
presence_penalty=self.presence_penalty,
|
presence_penalty=self.presence_penalty,
|
||||||
frequency_penalty=self.frequency_penalty,
|
frequency_penalty=self.frequency_penalty,
|
||||||
repetition_penalty=repetition_penalty,
|
repetition_penalty=repetition_penalty,
|
||||||
|
@ -168,8 +168,12 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||||||
model_name = self._get_model_name(request.model, lora_request)
|
model_name = self._get_model_name(request.model, lora_request)
|
||||||
num_prompts = len(engine_prompts)
|
num_prompts = len(engine_prompts)
|
||||||
|
|
||||||
# We do not stream the results when use beam search.
|
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
||||||
stream = (request.stream and not request.use_beam_search)
|
# results. Noting that best_of is only supported in V0. In addition,
|
||||||
|
# we do not stream the results when use beam search.
|
||||||
|
stream = (request.stream
|
||||||
|
and (request.best_of is None or request.n == request.best_of)
|
||||||
|
and not request.use_beam_search)
|
||||||
|
|
||||||
# Streaming response
|
# Streaming response
|
||||||
if stream:
|
if stream:
|
||||||
|
@ -116,6 +116,10 @@ class SamplingParams(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
n: Number of output sequences to return for the given prompt.
|
n: Number of output sequences to return for the given prompt.
|
||||||
|
best_of: Number of output sequences that are generated from the prompt.
|
||||||
|
From these `best_of` sequences, the top `n` sequences are returned.
|
||||||
|
`best_of` must be greater than or equal to `n`. By default,
|
||||||
|
`best_of` is set to `n`. Warning, this is only supported in V0.
|
||||||
presence_penalty: Float that penalizes new tokens based on whether they
|
presence_penalty: Float that penalizes new tokens based on whether they
|
||||||
appear in the generated text so far. Values > 0 encourage the model
|
appear in the generated text so far. Values > 0 encourage the model
|
||||||
to use new tokens, while values < 0 encourage the model to repeat
|
to use new tokens, while values < 0 encourage the model to repeat
|
||||||
@ -183,6 +187,7 @@ class SamplingParams(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
n: int = 1
|
n: int = 1
|
||||||
|
best_of: Optional[int] = None
|
||||||
_real_n: Optional[int] = None
|
_real_n: Optional[int] = None
|
||||||
presence_penalty: float = 0.0
|
presence_penalty: float = 0.0
|
||||||
frequency_penalty: float = 0.0
|
frequency_penalty: float = 0.0
|
||||||
@ -226,6 +231,7 @@ class SamplingParams(
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_optional(
|
def from_optional(
|
||||||
n: Optional[int] = 1,
|
n: Optional[int] = 1,
|
||||||
|
best_of: Optional[int] = None,
|
||||||
presence_penalty: Optional[float] = 0.0,
|
presence_penalty: Optional[float] = 0.0,
|
||||||
frequency_penalty: Optional[float] = 0.0,
|
frequency_penalty: Optional[float] = 0.0,
|
||||||
repetition_penalty: Optional[float] = 1.0,
|
repetition_penalty: Optional[float] = 1.0,
|
||||||
@ -264,6 +270,7 @@ class SamplingParams(
|
|||||||
|
|
||||||
return SamplingParams(
|
return SamplingParams(
|
||||||
n=1 if n is None else n,
|
n=1 if n is None else n,
|
||||||
|
best_of=best_of,
|
||||||
presence_penalty=0.0
|
presence_penalty=0.0
|
||||||
if presence_penalty is None else presence_penalty,
|
if presence_penalty is None else presence_penalty,
|
||||||
frequency_penalty=0.0
|
frequency_penalty=0.0
|
||||||
@ -296,6 +303,20 @@ class SamplingParams(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
# how we deal with `best_of``:
|
||||||
|
# if `best_of`` is not set, we default to `n`;
|
||||||
|
# if `best_of`` is set, we set `n`` to `best_of`,
|
||||||
|
# and set `_real_n`` to the original `n`.
|
||||||
|
# when we return the result, we will check
|
||||||
|
# if we need to return `n` or `_real_n` results
|
||||||
|
if self.best_of:
|
||||||
|
if self.best_of < self.n:
|
||||||
|
raise ValueError(
|
||||||
|
f"best_of must be greater than or equal to n, "
|
||||||
|
f"got n={self.n} and best_of={self.best_of}.")
|
||||||
|
if not self._real_n:
|
||||||
|
self._real_n = self.n
|
||||||
|
self.n = self.best_of
|
||||||
|
|
||||||
if 0 < self.temperature < _MAX_TEMP:
|
if 0 < self.temperature < _MAX_TEMP:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@ -402,6 +423,9 @@ class SamplingParams(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"stop strings are only supported when detokenize is True. "
|
"stop strings are only supported when detokenize is True. "
|
||||||
"Set detokenize=True to use stop.")
|
"Set detokenize=True to use stop.")
|
||||||
|
if self.best_of != self._real_n and self.output_kind == (
|
||||||
|
RequestOutputKind.DELTA):
|
||||||
|
raise ValueError("best_of must equal n to use output_kind=DELTA")
|
||||||
|
|
||||||
def _verify_greedy_sampling(self) -> None:
|
def _verify_greedy_sampling(self) -> None:
|
||||||
if self.n > 1:
|
if self.n > 1:
|
||||||
|
@ -93,6 +93,9 @@ class Processor:
|
|||||||
self,
|
self,
|
||||||
params: SamplingParams,
|
params: SamplingParams,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
# Best of not yet supported.
|
||||||
|
if params.best_of is not None and params.best_of > 1:
|
||||||
|
raise ValueError("VLLM V1 does not yet support best_of.")
|
||||||
# Bad words not yet supported.
|
# Bad words not yet supported.
|
||||||
if params.bad_words:
|
if params.bad_words:
|
||||||
raise ValueError("VLLM V1 does not yet support bad_words.")
|
raise ValueError("VLLM V1 does not yet support bad_words.")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user