From a4f1ee35d6864d305cd0521607743f80e7ad5d6b Mon Sep 17 00:00:00 2001
From: Vincent
Date: Wed, 5 Mar 2025 15:22:43 -0500
Subject: [PATCH] Deprecate `best_of` Sampling Parameter in anticipation of
 vLLM V1 (#13997)

Signed-off-by: vincent-4
Signed-off-by: Brayden Zhong
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Brayden Zhong
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
---
 benchmarks/backend_request_func.py            |  5 ----
 benchmarks/benchmark_serving.py               | 14 ----------
 .../offline_inference/llm_engine_example.py   |  1 -
 .../opentelemetry/dummy_client.py             |  1 -
 tests/core/test_scheduler.py                  |  4 ---
 tests/core/utils.py                           | 27 +++++++++----------
 tests/v1/sample/test_sampling_params_e2e.py   |  8 ------
 vllm/entrypoints/llm.py                       |  5 +---
 vllm/entrypoints/openai/protocol.py           |  4 ---
 vllm/entrypoints/openai/serving_completion.py |  8 ++----
 vllm/sampling_params.py                       | 24 -----------------
 vllm/v1/engine/processor.py                   |  3 ---
 12 files changed, 16 insertions(+), 88 deletions(-)

diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py
index 15870576..d53428d2 100644
--- a/benchmarks/backend_request_func.py
+++ b/benchmarks/backend_request_func.py
@@ -27,7 +27,6 @@ class RequestFuncInput:
     output_len: int
     model: str
     model_name: Optional[str] = None
-    best_of: int = 1
     logprobs: Optional[int] = None
     extra_body: Optional[dict] = None
     multi_modal_content: Optional[dict] = None
@@ -58,7 +57,6 @@ async def async_request_tgi(
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
         params = {
-            "best_of": request_func_input.best_of,
             "max_new_tokens": request_func_input.output_len,
             "do_sample": True,
             "temperature": 0.01,  # TGI does not accept 0.0 temperature.
@@ -130,7 +128,6 @@ async def async_request_trt_llm(
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1
         payload = {
             "accumulate_tokens": True,
             "text_input": request_func_input.prompt,
@@ -195,7 +192,6 @@ async def async_request_deepspeed_mii(
 ) -> RequestFuncOutput:
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1
 
         payload = {
             "prompt": request_func_input.prompt,
@@ -249,7 +245,6 @@ async def async_request_openai_completions(
         if request_func_input.model_name else request_func_input.model,
         "prompt": request_func_input.prompt,
         "temperature": 0.0,
-        "best_of": request_func_input.best_of,
         "max_tokens": request_func_input.output_len,
         "logprobs": request_func_input.logprobs,
         "stream": True,
diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py
index 16ec0a48..68ca2dc8 100644
--- a/benchmarks/benchmark_serving.py
+++ b/benchmarks/benchmark_serving.py
@@ -560,7 +560,6 @@ async def benchmark(
     tokenizer: PreTrainedTokenizerBase,
     input_requests: list[tuple[str, int, int]],
     logprobs: Optional[int],
-    best_of: int,
     request_rate: float,
     burstiness: float,
     disable_tqdm: bool,
@@ -592,7 +591,6 @@ async def benchmark(
         prompt_len=test_prompt_len,
         output_len=test_output_len,
         logprobs=logprobs,
-        best_of=best_of,
         multi_modal_content=test_mm_content,
         ignore_eos=ignore_eos,
     )
@@ -619,7 +617,6 @@ async def benchmark(
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=test_mm_content,
             ignore_eos=ignore_eos)
         profile_output = await request_func(request_func_input=profile_input)
@@ -668,7 +665,6 @@ async def benchmark(
             prompt_len=prompt_len,
             output_len=output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=mm_content,
             ignore_eos=ignore_eos)
         tasks.append(
@@ -686,7 +682,6 @@ async def benchmark(
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
         )
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
@@ -958,7 +953,6 @@ def main(args: argparse.Namespace):
             tokenizer=tokenizer,
             input_requests=input_requests,
             logprobs=args.logprobs,
-            best_of=args.best_of,
             request_rate=args.request_rate,
             burstiness=args.burstiness,
             disable_tqdm=args.disable_tqdm,
@@ -983,7 +977,6 @@ def main(args: argparse.Namespace):
         result_json["backend"] = backend
         result_json["model_id"] = model_id
         result_json["tokenizer_id"] = tokenizer_id
-        result_json["best_of"] = args.best_of
         result_json["num_prompts"] = args.num_prompts
 
         # Metadata
@@ -1081,13 +1074,6 @@ if __name__ == "__main__":
         help=
         "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
-    parser.add_argument(
-        "--best-of",
-        type=int,
-        default=1,
-        help="Generates `best_of` sequences per prompt and "
-        "returns the best one.",
-    )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
         "--num-prompts",
diff --git a/examples/offline_inference/llm_engine_example.py b/examples/offline_inference/llm_engine_example.py
index f7741a37..e94f47b7 100644
--- a/examples/offline_inference/llm_engine_example.py
+++ b/examples/offline_inference/llm_engine_example.py
@@ -15,7 +15,6 @@ def create_test_prompts() -> list[tuple[str, SamplingParams]]:
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
          SamplingParams(n=2,
-                        best_of=5,
                         temperature=0.8,
                         top_p=0.95,
                         frequency_penalty=0.1)),
diff --git a/examples/online_serving/opentelemetry/dummy_client.py b/examples/online_serving/opentelemetry/dummy_client.py
index 7a605f85..a8b35309 100644
--- a/examples/online_serving/opentelemetry/dummy_client.py
+++ b/examples/online_serving/opentelemetry/dummy_client.py
@@ -28,7 +28,6 @@ with tracer.start_as_current_span("client-span", kind=SpanKind.CLIENT) as span:
         "model": "facebook/opt-125m",
         "prompt": prompt,
         "max_tokens": 10,
-        "best_of": 20,
         "n": 3,
         "use_beam_search": "true",
         "temperature": 0.0,
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index 9e461d4e..8bd64923 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -617,7 +617,6 @@ def test_schedule_decode_blocks_to_copy_update():
                                     num_gpu_blocks=16)
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     curr_loras = None
     scheduler._allocate_and_set_running(seq_group)
@@ -686,7 +685,6 @@ def test_schedule_swapped_cannot_swap_in():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -717,7 +715,6 @@ def test_infeasible_swap():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -747,7 +744,6 @@ def test_schedule_swapped_blocks_to_copy():
     curr_loras = None
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
diff --git a/tests/core/utils.py b/tests/core/utils.py
index ba4265e3..ea18b879 100644
--- a/tests/core/utils.py
+++ b/tests/core/utils.py
@@ -18,7 +18,6 @@ def create_dummy_prompt(
     prompt_length: int = -1,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
     prompt_tokens: Optional[list[int]] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
@@ -32,17 +31,19 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))
 
     prompt_str = " ".join([str(t) for t in prompt_tokens])
-    prompt = Sequence(int(request_id),
-                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
-                      block_size=block_size)
-    seq_group = SequenceGroup(request_id=request_id,
-                              seqs=[prompt],
-                              arrival_time=time.time(),
-                              sampling_params=SamplingParams(
-                                  best_of=best_of,
-                                  max_tokens=max_tokens,
-                                  min_tokens=min_tokens),
-                              lora_request=lora_request)
+    prompt = Sequence(
+        int(request_id),
+        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        block_size=block_size,
+    )
+    seq_group = SequenceGroup(
+        request_id=request_id,
+        seqs=[prompt],
+        arrival_time=time.time(),
+        sampling_params=SamplingParams(max_tokens=max_tokens,
+                                       min_tokens=min_tokens),
+        lora_request=lora_request,
+    )
 
     return prompt, seq_group
 
@@ -72,7 +73,6 @@ def create_dummy_prompt_encoder_decoder(
     encoder_prompt_length: int,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
 ) -> tuple[Sequence, Sequence, SequenceGroup]:
     if not block_size:
         block_size = decoder_prompt_length
@@ -102,7 +102,6 @@ def create_dummy_prompt_encoder_decoder(
 
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
-                              sampling_params=SamplingParams(best_of=best_of),
                               arrival_time=time.time(),
                               lora_request=lora_request,
                               encoder_seq=encoder_prompt)
diff --git a/tests/v1/sample/test_sampling_params_e2e.py b/tests/v1/sample/test_sampling_params_e2e.py
index e47f13f0..f17d4b77 100644
--- a/tests/v1/sample/test_sampling_params_e2e.py
+++ b/tests/v1/sample/test_sampling_params_e2e.py
@@ -25,14 +25,6 @@ def test_n_gt_1(model):
     assert len(outputs[0].outputs) == 3
 
 
-def test_best_of(model):
-    """Raise a ValueError since best_of is deprecated."""
-
-    params = SamplingParams(n=2, best_of=3)
-    with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, params)
-
-
 def test_penalties(model):
     """Check that we do not get errors if applied."""
 
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index fc585ee9..dd46a137 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -97,10 +97,7 @@ class LLM:
             throughput. However, if the value is too high, it may cause out-of-
             memory (OOM) errors.
         swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
-            This can be used for temporarily storing the states of the requests
-            when their `best_of` sampling parameters are larger than 1. If all
-            requests will have `best_of=1`, you can safely set this to 0.
-            Otherwise, too small values may cause out-of-memory (OOM) errors.
+            Too small values may cause out-of-memory (OOM) errors.
         cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
             the model weights. This virtually increases the GPU memory space
             you can use to hold the model weights, at the cost of CPU-GPU data
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 2c740caf..4c4d86fd 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -242,7 +242,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None
 
     # doc: begin-chat-completion-sampling-params
-    best_of: Optional[int] = None
     use_beam_search: bool = False
     top_k: Optional[int] = None
     min_p: Optional[float] = None
@@ -479,7 +478,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
@@ -650,7 +648,6 @@ class CompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
     prompt: Union[list[int], list[list[int]], str, list[str]]
-    best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[dict[str, float]] = None
@@ -848,7 +845,6 @@ class CompletionRequest(OpenAIBaseModel):
 
         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index edcf1b08..592f213b 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -168,12 +168,8 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
-        # Similar to the OpenAI API, when n != best_of, we do not stream the
-        # results. In addition, we do not stream the results when use
-        # beam search.
-        stream = (request.stream
-                  and (request.best_of is None or request.n == request.best_of)
-                  and not request.use_beam_search)
+        # We do not stream the results when using beam search.
+        stream = (request.stream and not request.use_beam_search)
 
         # Streaming response
         if stream:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 17e4e433..599d52ee 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -116,10 +116,6 @@ class SamplingParams(
 
     Args:
         n: Number of output sequences to return for the given prompt.
-        best_of: Number of output sequences that are generated from the prompt.
-            From these `best_of` sequences, the top `n` sequences are returned.
-            `best_of` must be greater than or equal to `n`. By default,
-            `best_of` is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -187,7 +183,6 @@ class SamplingParams(
     """
 
     n: int = 1
-    best_of: Optional[int] = None
     _real_n: Optional[int] = None
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
@@ -231,7 +226,6 @@ class SamplingParams(
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
-        best_of: Optional[int] = None,
         presence_penalty: Optional[float] = 0.0,
         frequency_penalty: Optional[float] = 0.0,
         repetition_penalty: Optional[float] = 1.0,
@@ -270,7 +264,6 @@ class SamplingParams(
 
         return SamplingParams(
             n=1 if n is None else n,
-            best_of=best_of,
             presence_penalty=0.0
             if presence_penalty is None else presence_penalty,
             frequency_penalty=0.0
@@ -303,20 +296,6 @@ class SamplingParams(
         )
 
     def __post_init__(self) -> None:
-        # how we deal with `best_of``:
-        # if `best_of`` is not set, we default to `n`;
-        # if `best_of`` is set, we set `n`` to `best_of`,
-        # and set `_real_n`` to the original `n`.
-        # when we return the result, we will check
-        # if we need to return `n` or `_real_n` results
-        if self.best_of:
-            if self.best_of < self.n:
-                raise ValueError(
-                    f"best_of must be greater than or equal to n, "
-                    f"got n={self.n} and best_of={self.best_of}.")
-            if not self._real_n:
-                self._real_n = self.n
-                self.n = self.best_of
         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
@@ -423,9 +402,6 @@ class SamplingParams(
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
                 "Set detokenize=True to use stop.")
-        if self.best_of != self._real_n and self.output_kind == (
-                RequestOutputKind.DELTA):
-            raise ValueError("best_of must equal n to use output_kind=DELTA")
 
     def _verify_greedy_sampling(self) -> None:
         if self.n > 1:
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 713a5d38..6a2c1c54 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -93,9 +93,6 @@ class Processor:
         self,
         params: SamplingParams,
     ) -> None:
-        # Best of not yet supported.
-        if params.best_of:
-            raise ValueError("VLLM V1 does not yet support best_of.")
         # Bad words not yet supported.
         if params.bad_words:
             raise ValueError("VLLM V1 does not yet support bad_words.")
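
Note for downstream users (not part of the patch): with `best_of` removed, callers
that previously oversampled and let vLLM pick the winners can request all
candidates via `n` and rank them client-side. The sketch below is illustrative
only; it reuses the `facebook/opt-125m` model from the examples above and uses
`cumulative_logprob` ranking as a stand-in for the old server-side selection.

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")

    # Previously: SamplingParams(n=2, best_of=5) generated five candidates and
    # returned the best two. Without best_of, request all five via n and rank
    # them yourself. logprobs=0 requests the sampled tokens' logprobs so that
    # cumulative_logprob is populated (it may be None otherwise).
    params = SamplingParams(n=5, temperature=0.8, top_p=0.95, logprobs=0)
    outputs = llm.generate(["What is the meaning of life?"], params)

    # Keep the two candidates with the highest cumulative log probability.
    ranked = sorted(outputs[0].outputs,
                    key=lambda c: c.cumulative_logprob or float("-inf"),
                    reverse=True)
    for completion in ranked[:2]:
        print(completion.text)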