Deprecate best_of Sampling Parameter in anticipation for vLLM V1 (#13997)

Signed-off-by: vincent-4 <vincentzhongy+githubvincent4@gmail.com>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Vincent 2025-03-05 15:22:43 -05:00 committed by GitHub
parent a32c8669ca
commit a4f1ee35d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 16 additions and 88 deletions

View File

@ -27,7 +27,6 @@ class RequestFuncInput:
output_len: int
model: str
model_name: Optional[str] = None
best_of: int = 1
logprobs: Optional[int] = None
extra_body: Optional[dict] = None
multi_modal_content: Optional[dict] = None
@ -58,7 +57,6 @@ async def async_request_tgi(
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
params = {
"best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len,
"do_sample": True,
"temperature": 0.01, # TGI does not accept 0.0 temperature.
@ -130,7 +128,6 @@ async def async_request_trt_llm(
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1
payload = {
"accumulate_tokens": True,
"text_input": request_func_input.prompt,
@ -195,7 +192,6 @@ async def async_request_deepspeed_mii(
) -> RequestFuncOutput:
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1
payload = {
"prompt": request_func_input.prompt,
@ -249,7 +245,6 @@ async def async_request_openai_completions(
if request_func_input.model_name else request_func_input.model,
"prompt": request_func_input.prompt,
"temperature": 0.0,
"best_of": request_func_input.best_of,
"max_tokens": request_func_input.output_len,
"logprobs": request_func_input.logprobs,
"stream": True,

View File

@ -560,7 +560,6 @@ async def benchmark(
tokenizer: PreTrainedTokenizerBase,
input_requests: list[tuple[str, int, int]],
logprobs: Optional[int],
best_of: int,
request_rate: float,
burstiness: float,
disable_tqdm: bool,
@ -592,7 +591,6 @@ async def benchmark(
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
)
@ -619,7 +617,6 @@ async def benchmark(
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos)
profile_output = await request_func(request_func_input=profile_input)
@ -668,7 +665,6 @@ async def benchmark(
prompt_len=prompt_len,
output_len=output_len,
logprobs=logprobs,
best_of=best_of,
multi_modal_content=mm_content,
ignore_eos=ignore_eos)
tasks.append(
@ -686,7 +682,6 @@ async def benchmark(
prompt_len=test_prompt_len,
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@ -958,7 +953,6 @@ def main(args: argparse.Namespace):
tokenizer=tokenizer,
input_requests=input_requests,
logprobs=args.logprobs,
best_of=args.best_of,
request_rate=args.request_rate,
burstiness=args.burstiness,
disable_tqdm=args.disable_tqdm,
@ -983,7 +977,6 @@ def main(args: argparse.Namespace):
result_json["backend"] = backend
result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of
result_json["num_prompts"] = args.num_prompts
# Metadata
@ -1081,13 +1074,6 @@ if __name__ == "__main__":
help=
"Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501
)
parser.add_argument(
"--best-of",
type=int,
default=1,
help="Generates `best_of` sequences per prompt and "
"returns the best one.",
)
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument(
"--num-prompts",

View File

@ -15,7 +15,6 @@ def create_test_prompts() -> list[tuple[str, SamplingParams]]:
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
SamplingParams(n=2,
best_of=5,
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),

View File

@ -28,7 +28,6 @@ with tracer.start_as_current_span("client-span", kind=SpanKind.CLIENT) as span:
"model": "facebook/opt-125m",
"prompt": prompt,
"max_tokens": 10,
"best_of": 20,
"n": 3,
"use_beam_search": "true",
"temperature": 0.0,

View File

@ -617,7 +617,6 @@ def test_schedule_decode_blocks_to_copy_update():
num_gpu_blocks=16)
_, seq_group = create_dummy_prompt("1",
prompt_length=60,
best_of=2,
block_size=block_size)
curr_loras = None
scheduler._allocate_and_set_running(seq_group)
@ -686,7 +685,6 @@ def test_schedule_swapped_cannot_swap_in():
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
@ -717,7 +715,6 @@ def test_infeasible_swap():
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
@ -747,7 +744,6 @@ def test_schedule_swapped_blocks_to_copy():
curr_loras = None
_, seq_group = create_dummy_prompt("1",
prompt_length=60,
best_of=2,
block_size=block_size)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)

View File

@ -18,7 +18,6 @@ def create_dummy_prompt(
prompt_length: int = -1,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
best_of: int = 1,
prompt_tokens: Optional[list[int]] = None,
min_tokens: int = 0,
max_tokens: int = 16,
@ -32,17 +31,19 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
prompt = Sequence(int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
block_size=block_size)
seq_group = SequenceGroup(request_id=request_id,
seqs=[prompt],
arrival_time=time.time(),
sampling_params=SamplingParams(
best_of=best_of,
max_tokens=max_tokens,
min_tokens=min_tokens),
lora_request=lora_request)
prompt = Sequence(
int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
block_size=block_size,
)
seq_group = SequenceGroup(
request_id=request_id,
seqs=[prompt],
arrival_time=time.time(),
sampling_params=SamplingParams(max_tokens=max_tokens,
min_tokens=min_tokens),
lora_request=lora_request,
)
return prompt, seq_group
@ -72,7 +73,6 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
best_of: int = 1,
) -> tuple[Sequence, Sequence, SequenceGroup]:
if not block_size:
block_size = decoder_prompt_length
@ -102,7 +102,6 @@ def create_dummy_prompt_encoder_decoder(
seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(best_of=best_of),
arrival_time=time.time(),
lora_request=lora_request,
encoder_seq=encoder_prompt)

View File

@ -25,14 +25,6 @@ def test_n_gt_1(model):
assert len(outputs[0].outputs) == 3
def test_best_of(model):
"""Raise a ValueError since best_of is deprecated."""
params = SamplingParams(n=2, best_of=3)
with pytest.raises(ValueError):
_ = model.generate(PROMPT, params)
def test_penalties(model):
"""Check that we do not get errors if applied."""

View File

@ -97,10 +97,7 @@ class LLM:
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
Too small values may cause out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data

View File

@ -242,7 +242,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
user: Optional[str] = None
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: Optional[int] = None
min_p: Optional[float] = None
@ -479,7 +478,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
@ -650,7 +648,6 @@ class CompletionRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = None
prompt: Union[list[int], list[list[int]], str, list[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[dict[str, float]] = None
@ -848,7 +845,6 @@ class CompletionRequest(OpenAIBaseModel):
return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,

View File

@ -168,12 +168,8 @@ class OpenAIServingCompletion(OpenAIServing):
model_name = self._get_model_name(request.model, lora_request)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use
# beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# We do not stream the results when use beam search.
stream = (request.stream and not request.use_beam_search)
# Streaming response
if stream:

View File

@ -116,10 +116,6 @@ class SamplingParams(
Args:
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. By default,
`best_of` is set to `n`.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
@ -187,7 +183,6 @@ class SamplingParams(
"""
n: int = 1
best_of: Optional[int] = None
_real_n: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
@ -231,7 +226,6 @@ class SamplingParams(
@staticmethod
def from_optional(
n: Optional[int] = 1,
best_of: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
@ -270,7 +264,6 @@ class SamplingParams(
return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=0.0
if presence_penalty is None else presence_penalty,
frequency_penalty=0.0
@ -303,20 +296,6 @@ class SamplingParams(
)
def __post_init__(self) -> None:
# how we deal with `best_of``:
# if `best_of`` is not set, we default to `n`;
# if `best_of`` is set, we set `n`` to `best_of`,
# and set `_real_n`` to the original `n`.
# when we return the result, we will check
# if we need to return `n` or `_real_n` results
if self.best_of:
if self.best_of < self.n:
raise ValueError(
f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
if not self._real_n:
self._real_n = self.n
self.n = self.best_of
if 0 < self.temperature < _MAX_TEMP:
logger.warning(
@ -423,9 +402,6 @@ class SamplingParams(
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self._real_n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_greedy_sampling(self) -> None:
if self.n > 1:

View File

@ -93,9 +93,6 @@ class Processor:
self,
params: SamplingParams,
) -> None:
# Best of not yet supported.
if params.best_of:
raise ValueError("VLLM V1 does not yet support best_of.")
# Bad words not yet supported.
if params.bad_words:
raise ValueError("VLLM V1 does not yet support bad_words.")