Deprecate best_of Sampling Parameter in anticipation of vLLM V1 (#13997)

Signed-off-by: vincent-4 <vincentzhongy+githubvincent4@gmail.com>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
  parent a32c8669ca
  commit a4f1ee35d6
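Migration note: callers that previously passed `best_of` to `SamplingParams` should drop the argument and rely on `n` alone. The snippet below is a minimal sketch based on the example change in this diff; the parameter values are illustrative only.

    from vllm import SamplingParams

    # Before this change: generate 5 candidates, return the top 2.
    # params = SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95)

    # After this change: best_of is no longer accepted; request exactly
    # the number of sequences you want returned.
    params = SamplingParams(n=2, temperature=0.8, top_p=0.95)
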
@@ -27,7 +27,6 @@ class RequestFuncInput:
     output_len: int
     model: str
     model_name: Optional[str] = None
-    best_of: int = 1
     logprobs: Optional[int] = None
     extra_body: Optional[dict] = None
     multi_modal_content: Optional[dict] = None
@@ -58,7 +57,6 @@ async def async_request_tgi(
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
         params = {
-            "best_of": request_func_input.best_of,
             "max_new_tokens": request_func_input.output_len,
             "do_sample": True,
             "temperature": 0.01,  # TGI does not accept 0.0 temperature.
@@ -130,7 +128,6 @@ async def async_request_trt_llm(

     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1
         payload = {
             "accumulate_tokens": True,
             "text_input": request_func_input.prompt,
@@ -195,7 +192,6 @@ async def async_request_deepspeed_mii(
 ) -> RequestFuncOutput:
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1

         payload = {
             "prompt": request_func_input.prompt,
@@ -249,7 +245,6 @@ async def async_request_openai_completions(
             if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
-            "best_of": request_func_input.best_of,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
             "stream": True,

@@ -560,7 +560,6 @@ async def benchmark(
     tokenizer: PreTrainedTokenizerBase,
     input_requests: list[tuple[str, int, int]],
     logprobs: Optional[int],
-    best_of: int,
     request_rate: float,
     burstiness: float,
     disable_tqdm: bool,
@@ -592,7 +591,6 @@ async def benchmark(
         prompt_len=test_prompt_len,
         output_len=test_output_len,
         logprobs=logprobs,
-        best_of=best_of,
         multi_modal_content=test_mm_content,
         ignore_eos=ignore_eos,
     )
@@ -619,7 +617,6 @@ async def benchmark(
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=test_mm_content,
             ignore_eos=ignore_eos)
         profile_output = await request_func(request_func_input=profile_input)
@@ -668,7 +665,6 @@ async def benchmark(
             prompt_len=prompt_len,
             output_len=output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=mm_content,
             ignore_eos=ignore_eos)
         tasks.append(
@@ -686,7 +682,6 @@ async def benchmark(
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
         )
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
@@ -958,7 +953,6 @@ def main(args: argparse.Namespace):
         tokenizer=tokenizer,
         input_requests=input_requests,
         logprobs=args.logprobs,
-        best_of=args.best_of,
         request_rate=args.request_rate,
         burstiness=args.burstiness,
         disable_tqdm=args.disable_tqdm,
@@ -983,7 +977,6 @@ def main(args: argparse.Namespace):
         result_json["backend"] = backend
         result_json["model_id"] = model_id
         result_json["tokenizer_id"] = tokenizer_id
-        result_json["best_of"] = args.best_of
         result_json["num_prompts"] = args.num_prompts

         # Metadata
@@ -1081,13 +1074,6 @@ if __name__ == "__main__":
         help=
         "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
-    parser.add_argument(
-        "--best-of",
-        type=int,
-        default=1,
-        help="Generates `best_of` sequences per prompt and "
-        "returns the best one.",
-    )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
         "--num-prompts",

@@ -15,7 +15,6 @@ def create_test_prompts() -> list[tuple[str, SamplingParams]]:
         SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
         SamplingParams(n=2,
-                       best_of=5,
                        temperature=0.8,
                        top_p=0.95,
                        frequency_penalty=0.1)),

@@ -28,7 +28,6 @@ with tracer.start_as_current_span("client-span", kind=SpanKind.CLIENT) as span:
         "model": "facebook/opt-125m",
         "prompt": prompt,
         "max_tokens": 10,
-        "best_of": 20,
         "n": 3,
         "use_beam_search": "true",
         "temperature": 0.0,

@@ -617,7 +617,6 @@ def test_schedule_decode_blocks_to_copy_update():
                                 num_gpu_blocks=16)
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     curr_loras = None
     scheduler._allocate_and_set_running(seq_group)
@@ -686,7 +685,6 @@ def test_schedule_swapped_cannot_swap_in():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -717,7 +715,6 @@ def test_infeasible_swap():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -747,7 +744,6 @@ def test_schedule_swapped_blocks_to_copy():
     curr_loras = None
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)

@@ -18,7 +18,6 @@ def create_dummy_prompt(
     prompt_length: int = -1,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
     prompt_tokens: Optional[list[int]] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
@@ -32,17 +31,19 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))

     prompt_str = " ".join([str(t) for t in prompt_tokens])
-    prompt = Sequence(int(request_id),
-                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
-                      block_size=block_size)
-    seq_group = SequenceGroup(request_id=request_id,
-                              seqs=[prompt],
-                              arrival_time=time.time(),
-                              sampling_params=SamplingParams(
-                                  best_of=best_of,
-                                  max_tokens=max_tokens,
-                                  min_tokens=min_tokens),
-                              lora_request=lora_request)
+    prompt = Sequence(
+        int(request_id),
+        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        block_size=block_size,
+    )
+    seq_group = SequenceGroup(
+        request_id=request_id,
+        seqs=[prompt],
+        arrival_time=time.time(),
+        sampling_params=SamplingParams(max_tokens=max_tokens,
+                                       min_tokens=min_tokens),
+        lora_request=lora_request,
+    )

     return prompt, seq_group

@@ -72,7 +73,6 @@ def create_dummy_prompt_encoder_decoder(
     encoder_prompt_length: int,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
 ) -> tuple[Sequence, Sequence, SequenceGroup]:
     if not block_size:
         block_size = decoder_prompt_length
@@ -102,7 +102,6 @@ def create_dummy_prompt_encoder_decoder(

     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
-                              sampling_params=SamplingParams(best_of=best_of),
                               arrival_time=time.time(),
                               lora_request=lora_request,
                               encoder_seq=encoder_prompt)

@@ -25,14 +25,6 @@ def test_n_gt_1(model):
     assert len(outputs[0].outputs) == 3


-def test_best_of(model):
-    """Raise a ValueError since best_of is deprecated."""
-
-    params = SamplingParams(n=2, best_of=3)
-    with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, params)
-
-
 def test_penalties(model):
     """Check that we do not get errors if applied."""

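For callers that still want best-of-style selection, one option is to over-generate with `n` and pick the winner client-side. The sketch below is an illustrative workaround, not part of this change; it assumes the offline `LLM.generate` API with a placeholder model, and that `cumulative_logprob` is populated on each `CompletionOutput` for the request.

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration

    # Ask for 5 candidates instead of best_of=5, then keep the best one ourselves.
    params = SamplingParams(n=5, temperature=0.8)
    outputs = llm.generate(["What is the meaning of life?"], params)

    for request_output in outputs:
        # Rank the n candidates by cumulative log-probability and keep the top one.
        best = max(request_output.outputs, key=lambda o: o.cumulative_logprob)
        print(best.text)
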
@@ -97,10 +97,7 @@ class LLM:
             throughput. However, if the value is too high, it may cause out-of-
             memory (OOM) errors.
         swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
-            This can be used for temporarily storing the states of the requests
-            when their `best_of` sampling parameters are larger than 1. If all
-            requests will have `best_of=1`, you can safely set this to 0.
-            Otherwise, too small values may cause out-of-memory (OOM) errors.
+            Too small values may cause out-of-memory (OOM) errors.
         cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
             the model weights. This virtually increases the GPU memory space
             you can use to hold the model weights, at the cost of CPU-GPU data

@@ -242,7 +242,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None

     # doc: begin-chat-completion-sampling-params
-    best_of: Optional[int] = None
     use_beam_search: bool = False
     top_k: Optional[int] = None
     min_p: Optional[float] = None
@@ -479,7 +478,6 @@ class ChatCompletionRequest(OpenAIBaseModel):

         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
@@ -650,7 +648,6 @@ class CompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
     prompt: Union[list[int], list[list[int]], str, list[str]]
-    best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[dict[str, float]] = None
@@ -848,7 +845,6 @@ class CompletionRequest(OpenAIBaseModel):

         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,

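With `best_of` dropped from `CompletionRequest` and `ChatCompletionRequest`, client payloads should control the number of completions with `n` only. A minimal sketch, assuming a local server at the default `http://localhost:8000`; the model name, prompt, and values are placeholders:

    import requests

    payload = {
        "model": "facebook/opt-125m",
        "prompt": "San Francisco is a",
        "max_tokens": 10,
        "n": 3,           # request the number of completions you actually want
        # "best_of": 5,   # removed from the request schema by this change
        "temperature": 0.8,
    }
    resp = requests.post("http://localhost:8000/v1/completions", json=payload)
    print(resp.json())
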
@@ -168,12 +168,8 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)

-        # Similar to the OpenAI API, when n != best_of, we do not stream the
-        # results. In addition, we do not stream the results when use
-        # beam search.
-        stream = (request.stream
-                  and (request.best_of is None or request.n == request.best_of)
-                  and not request.use_beam_search)
+        # We do not stream the results when use beam search.
+        stream = (request.stream and not request.use_beam_search)

         # Streaming response
         if stream:

@@ -116,10 +116,6 @@ class SamplingParams(

     Args:
         n: Number of output sequences to return for the given prompt.
-        best_of: Number of output sequences that are generated from the prompt.
-            From these `best_of` sequences, the top `n` sequences are returned.
-            `best_of` must be greater than or equal to `n`. By default,
-            `best_of` is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -187,7 +183,6 @@ class SamplingParams(
     """

     n: int = 1
-    best_of: Optional[int] = None
     _real_n: Optional[int] = None
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
@@ -231,7 +226,6 @@ class SamplingParams(
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
-        best_of: Optional[int] = None,
         presence_penalty: Optional[float] = 0.0,
         frequency_penalty: Optional[float] = 0.0,
         repetition_penalty: Optional[float] = 1.0,
@@ -270,7 +264,6 @@ class SamplingParams(

         return SamplingParams(
             n=1 if n is None else n,
-            best_of=best_of,
             presence_penalty=0.0
             if presence_penalty is None else presence_penalty,
             frequency_penalty=0.0
@@ -303,20 +296,6 @@ class SamplingParams(
         )

     def __post_init__(self) -> None:
-        # how we deal with `best_of``:
-        # if `best_of`` is not set, we default to `n`;
-        # if `best_of`` is set, we set `n`` to `best_of`,
-        # and set `_real_n`` to the original `n`.
-        # when we return the result, we will check
-        # if we need to return `n` or `_real_n` results
-        if self.best_of:
-            if self.best_of < self.n:
-                raise ValueError(
-                    f"best_of must be greater than or equal to n, "
-                    f"got n={self.n} and best_of={self.best_of}.")
-            if not self._real_n:
-                self._real_n = self.n
-                self.n = self.best_of

         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
@@ -423,9 +402,6 @@ class SamplingParams(
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
                 "Set detokenize=True to use stop.")
-        if self.best_of != self._real_n and self.output_kind == (
-                RequestOutputKind.DELTA):
-            raise ValueError("best_of must equal n to use output_kind=DELTA")

     def _verify_greedy_sampling(self) -> None:
         if self.n > 1:

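With the `__post_init__` remapping removed, `n` is no longer silently inflated to `best_of` and trimmed back via `_real_n`; a request for `n` sequences simply yields `n` candidates. A minimal sketch mirroring the `test_n_gt_1` assertion earlier in this diff; the model name is a placeholder:

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration

    params = SamplingParams(n=3, temperature=0.8)
    outputs = llm.generate(["What is the meaning of life?"], params)

    # Each RequestOutput now carries exactly n candidate sequences.
    assert len(outputs[0].outputs) == 3
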
@@ -93,9 +93,6 @@ class Processor:
         self,
         params: SamplingParams,
     ) -> None:
-        # Best of not yet supported.
-        if params.best_of:
-            raise ValueError("VLLM V1 does not yet support best_of.")
         # Bad words not yet supported.
         if params.bad_words:
             raise ValueError("VLLM V1 does not yet support bad_words.")