Deprecate best_of Sampling Parameter in anticipation of vLLM V1 (#13997)
Signed-off-by: vincent-4 <vincentzhongy+githubvincent4@gmail.com>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Commit: a4f1ee35d6
Parent: a32c8669ca
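
This commit removes the best_of sampling parameter end to end: from the benchmark client and CLI, the OpenAI-compatible request models, SamplingParams itself, and the V1 engine processor. For offline use the migration is to request the desired number of outputs directly through n. Below is a minimal sketch of the post-change API, assuming a vLLM build that already contains this commit; the prompt and sampling values are borrowed from the example script touched in this diff, and the model name from its tracing example.

# Minimal migration sketch (assumes a vLLM build that includes this change).
from vllm import LLM, SamplingParams

# Model name reused from the tracing example in this diff; use any model you have locally.
llm = LLM(model="facebook/opt-125m")

# Before this change, SamplingParams(n=2, best_of=5) generated five candidates and
# returned the top two. With best_of removed, ask for the outputs you want via n.
params = SamplingParams(n=2, temperature=0.8, top_p=0.95, frequency_penalty=0.1)

outputs = llm.generate("What is the meaning of life?", params)
for completion in outputs[0].outputs:
    print(completion.text)
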
@@ -27,7 +27,6 @@ class RequestFuncInput:
     output_len: int
     model: str
     model_name: Optional[str] = None
-    best_of: int = 1
     logprobs: Optional[int] = None
     extra_body: Optional[dict] = None
     multi_modal_content: Optional[dict] = None
@@ -58,7 +57,6 @@ async def async_request_tgi(
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
         params = {
-            "best_of": request_func_input.best_of,
             "max_new_tokens": request_func_input.output_len,
             "do_sample": True,
             "temperature": 0.01,  # TGI does not accept 0.0 temperature.
@@ -130,7 +128,6 @@ async def async_request_trt_llm(
 
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1
         payload = {
             "accumulate_tokens": True,
             "text_input": request_func_input.prompt,
@@ -195,7 +192,6 @@ async def async_request_deepspeed_mii(
 ) -> RequestFuncOutput:
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1
 
         payload = {
             "prompt": request_func_input.prompt,
@@ -249,7 +245,6 @@ async def async_request_openai_completions(
             if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
-            "best_of": request_func_input.best_of,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
             "stream": True,

@@ -560,7 +560,6 @@ async def benchmark(
     tokenizer: PreTrainedTokenizerBase,
     input_requests: list[tuple[str, int, int]],
     logprobs: Optional[int],
-    best_of: int,
     request_rate: float,
     burstiness: float,
     disable_tqdm: bool,
@@ -592,7 +591,6 @@ async def benchmark(
         prompt_len=test_prompt_len,
         output_len=test_output_len,
         logprobs=logprobs,
-        best_of=best_of,
         multi_modal_content=test_mm_content,
         ignore_eos=ignore_eos,
     )
@@ -619,7 +617,6 @@ async def benchmark(
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=test_mm_content,
             ignore_eos=ignore_eos)
         profile_output = await request_func(request_func_input=profile_input)
@@ -668,7 +665,6 @@ async def benchmark(
             prompt_len=prompt_len,
             output_len=output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=mm_content,
             ignore_eos=ignore_eos)
         tasks.append(
@@ -686,7 +682,6 @@ async def benchmark(
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
         )
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
@@ -958,7 +953,6 @@ def main(args: argparse.Namespace):
         tokenizer=tokenizer,
         input_requests=input_requests,
         logprobs=args.logprobs,
-        best_of=args.best_of,
         request_rate=args.request_rate,
         burstiness=args.burstiness,
         disable_tqdm=args.disable_tqdm,
@@ -983,7 +977,6 @@ def main(args: argparse.Namespace):
         result_json["backend"] = backend
         result_json["model_id"] = model_id
         result_json["tokenizer_id"] = tokenizer_id
-        result_json["best_of"] = args.best_of
         result_json["num_prompts"] = args.num_prompts
 
         # Metadata
@@ -1081,13 +1074,6 @@ if __name__ == "__main__":
         help=
         "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
-    parser.add_argument(
-        "--best-of",
-        type=int,
-        default=1,
-        help="Generates `best_of` sequences per prompt and "
-        "returns the best one.",
-    )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
         "--num-prompts",

@@ -15,7 +15,6 @@ def create_test_prompts() -> list[tuple[str, SamplingParams]]:
         SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
         SamplingParams(n=2,
-                        best_of=5,
                        temperature=0.8,
                        top_p=0.95,
                        frequency_penalty=0.1)),

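Standing alone, the updated entry from the example script above looks like the following; a small sketch assuming only that SamplingParams is importable from the installed vllm package.

from vllm import SamplingParams

# The example prompt after the change: two sampled outputs, no best_of argument.
test_prompt = ("What is the meaning of life?",
               SamplingParams(n=2,
                              temperature=0.8,
                              top_p=0.95,
                              frequency_penalty=0.1))
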
@@ -28,7 +28,6 @@ with tracer.start_as_current_span("client-span", kind=SpanKind.CLIENT) as span:
         "model": "facebook/opt-125m",
         "prompt": prompt,
         "max_tokens": 10,
-        "best_of": 20,
         "n": 3,
         "use_beam_search": "true",
         "temperature": 0.0,

@@ -617,7 +617,6 @@ def test_schedule_decode_blocks_to_copy_update():
                                     num_gpu_blocks=16)
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     curr_loras = None
     scheduler._allocate_and_set_running(seq_group)
@@ -686,7 +685,6 @@ def test_schedule_swapped_cannot_swap_in():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -717,7 +715,6 @@ def test_infeasible_swap():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -747,7 +744,6 @@ def test_schedule_swapped_blocks_to_copy():
     curr_loras = None
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)

@@ -18,7 +18,6 @@ def create_dummy_prompt(
     prompt_length: int = -1,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
     prompt_tokens: Optional[list[int]] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
@@ -32,17 +31,19 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))
 
     prompt_str = " ".join([str(t) for t in prompt_tokens])
-    prompt = Sequence(int(request_id),
-                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
-                      block_size=block_size)
-    seq_group = SequenceGroup(request_id=request_id,
-                              seqs=[prompt],
-                              arrival_time=time.time(),
-                              sampling_params=SamplingParams(
-                                  best_of=best_of,
-                                  max_tokens=max_tokens,
-                                  min_tokens=min_tokens),
-                              lora_request=lora_request)
+    prompt = Sequence(
+        int(request_id),
+        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        block_size=block_size,
+    )
+    seq_group = SequenceGroup(
+        request_id=request_id,
+        seqs=[prompt],
+        arrival_time=time.time(),
+        sampling_params=SamplingParams(max_tokens=max_tokens,
+                                       min_tokens=min_tokens),
+        lora_request=lora_request,
+    )
 
     return prompt, seq_group
 
@@ -72,7 +73,6 @@ def create_dummy_prompt_encoder_decoder(
     encoder_prompt_length: int,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
 ) -> tuple[Sequence, Sequence, SequenceGroup]:
     if not block_size:
         block_size = decoder_prompt_length
@@ -102,7 +102,6 @@ def create_dummy_prompt_encoder_decoder(
 
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
-                              sampling_params=SamplingParams(best_of=best_of),
                               arrival_time=time.time(),
                               lora_request=lora_request,
                               encoder_seq=encoder_prompt)

@@ -25,14 +25,6 @@ def test_n_gt_1(model):
     assert len(outputs[0].outputs) == 3
 
 
-def test_best_of(model):
-    """Raise a ValueError since best_of is deprecated."""
-
-    params = SamplingParams(n=2, best_of=3)
-    with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, params)
-
-
 def test_penalties(model):
     """Check that we do not get errors if applied."""
 

@@ -97,10 +97,7 @@ class LLM:
             throughput. However, if the value is too high, it may cause out-of-
             memory (OOM) errors.
         swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
-            This can be used for temporarily storing the states of the requests
-            when their `best_of` sampling parameters are larger than 1. If all
-            requests will have `best_of=1`, you can safely set this to 0.
-            Otherwise, too small values may cause out-of-memory (OOM) errors.
+            Too small values may cause out-of-memory (OOM) errors.
         cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
             the model weights. This virtually increases the GPU memory space
             you can use to hold the model weights, at the cost of CPU-GPU data

@@ -242,7 +242,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None
 
     # doc: begin-chat-completion-sampling-params
-    best_of: Optional[int] = None
     use_beam_search: bool = False
     top_k: Optional[int] = None
     min_p: Optional[float] = None
@@ -479,7 +478,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
@@ -650,7 +648,6 @@ class CompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
     prompt: Union[list[int], list[list[int]], str, list[str]]
-    best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[dict[str, float]] = None
@@ -848,7 +845,6 @@ class CompletionRequest(OpenAIBaseModel):
 
         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,

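On the server side, best_of disappears from both CompletionRequest and ChatCompletionRequest, so OpenAI-compatible clients should size the response with n instead of sending a best_of field. A hedged sketch of a /v1/completions call using only the standard library; the server address and prompt text are placeholders, while the model name and sampling values are reused from the tracing example earlier in this diff.

import json
from urllib.request import Request, urlopen

payload = {
    "model": "facebook/opt-125m",
    "prompt": "San Francisco is a",   # placeholder prompt
    "max_tokens": 10,
    "n": 3,            # request three completions directly; no "best_of" key
    "temperature": 0.0,
}
req = Request(
    "http://localhost:8000/v1/completions",   # assumed local vLLM server address
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urlopen(req) as resp:
    print(json.load(resp))
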
@@ -168,12 +168,8 @@ class OpenAIServingCompletion(OpenAIServing):
         model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
-        # Similar to the OpenAI API, when n != best_of, we do not stream the
-        # results. In addition, we do not stream the results when use
-        # beam search.
-        stream = (request.stream
-                  and (request.best_of is None or request.n == request.best_of)
-                  and not request.use_beam_search)
+        # We do not stream the results when use beam search.
+        stream = (request.stream and not request.use_beam_search)
 
         # Streaming response
         if stream:

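With best_of gone from the request model, the streaming decision in the completions handler collapses to a single beam-search check. The stand-alone sketch below mirrors that predicate with a plain namespace object in place of the real request model (an illustration, not the handler's actual code path).

from types import SimpleNamespace

# Stand-in for a parsed CompletionRequest; only the fields used by the check.
request = SimpleNamespace(stream=True, n=3, use_beam_search=False)

# Previously, streaming also required best_of to be unset or equal to n.
stream = request.stream and not request.use_beam_search
print(stream)  # True
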
@@ -116,10 +116,6 @@ class SamplingParams(
 
     Args:
         n: Number of output sequences to return for the given prompt.
-        best_of: Number of output sequences that are generated from the prompt.
-            From these `best_of` sequences, the top `n` sequences are returned.
-            `best_of` must be greater than or equal to `n`. By default,
-            `best_of` is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -187,7 +183,6 @@ class SamplingParams(
     """
 
     n: int = 1
-    best_of: Optional[int] = None
     _real_n: Optional[int] = None
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
@@ -231,7 +226,6 @@ class SamplingParams(
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
-        best_of: Optional[int] = None,
         presence_penalty: Optional[float] = 0.0,
         frequency_penalty: Optional[float] = 0.0,
         repetition_penalty: Optional[float] = 1.0,
@@ -270,7 +264,6 @@ class SamplingParams(
 
         return SamplingParams(
             n=1 if n is None else n,
-            best_of=best_of,
             presence_penalty=0.0
             if presence_penalty is None else presence_penalty,
             frequency_penalty=0.0
@@ -303,20 +296,6 @@ class SamplingParams(
         )
 
     def __post_init__(self) -> None:
-        # how we deal with `best_of`:
-        # if `best_of` is not set, we default to `n`;
-        # if `best_of` is set, we set `n` to `best_of`,
-        # and set `_real_n` to the original `n`.
-        # when we return the result, we will check
-        # if we need to return `n` or `_real_n` results
-        if self.best_of:
-            if self.best_of < self.n:
-                raise ValueError(
-                    f"best_of must be greater than or equal to n, "
-                    f"got n={self.n} and best_of={self.best_of}.")
-            if not self._real_n:
-                self._real_n = self.n
-                self.n = self.best_of
 
         if 0 < self.temperature < _MAX_TEMP:
             logger.warning(
@@ -423,9 +402,6 @@ class SamplingParams(
             raise ValueError(
                 "stop strings are only supported when detokenize is True. "
                 "Set detokenize=True to use stop.")
-        if self.best_of != self._real_n and self.output_kind == (
-                RequestOutputKind.DELTA):
-            raise ValueError("best_of must equal n to use output_kind=DELTA")
 
     def _verify_greedy_sampling(self) -> None:
         if self.n > 1:

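The net effect inside SamplingParams is that n is now the only knob for how many sequences come back: the __post_init__ remapping of n into best_of (tracked through _real_n) is deleted, along with the DELTA output-kind restriction. A minimal sketch, assuming a vLLM build with this change.

from vllm import SamplingParams

params = SamplingParams(n=2, temperature=0.8)
print(params.n)  # 2 -- not rewritten; there is no best_of/_real_n remapping anymore

# Passing best_of is no longer accepted, since the field itself has been removed:
# SamplingParams(n=2, best_of=3)  # would now fail at construction time
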
@@ -93,9 +93,6 @@ class Processor:
         self,
         params: SamplingParams,
     ) -> None:
-        # Best of not yet supported.
-        if params.best_of:
-            raise ValueError("VLLM V1 does not yet support best_of.")
         # Bad words not yet supported.
         if params.bad_words:
             raise ValueError("VLLM V1 does not yet support bad_words.")