[core] remove beam search from the core (#9105)

This commit is contained in:
youkaichao 2024-10-06 22:47:04 -07:00 committed by GitHub
parent c8f26bb636
commit 18b296fdb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 98 additions and 596 deletions

View File

@ -23,7 +23,6 @@ class RequestFuncInput:
output_len: int
model: str
best_of: int = 1
use_beam_search: bool = False
logprobs: Optional[int] = None
multi_modal_content: Optional[dict] = None
ignore_eos: bool = False
@ -49,7 +48,6 @@ async def async_request_tgi(
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
params = {
"best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len,
@ -121,7 +119,6 @@ async def async_request_trt_llm(
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
assert request_func_input.best_of == 1
payload = {
"accumulate_tokens": True,
@ -187,7 +184,6 @@ async def async_request_deepspeed_mii(
) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1
assert not request_func_input.use_beam_search
payload = {
"prompt": request_func_input.prompt,
@ -235,7 +231,6 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
@ -317,7 +312,6 @@ async def async_request_openai_chat_completions(
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert not request_func_input.use_beam_search
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)

View File

@ -51,9 +51,8 @@ def main(args: argparse.Namespace):
sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=args.output_len,
)

View File

@ -68,7 +68,6 @@ def run_vllm(
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
@ -114,9 +113,8 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
@ -144,15 +142,16 @@ def main(args: argparse.Namespace):
args.output_len)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.gpu_memory_utilization,
args.download_dir)
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.trust_remote_code,
args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching,
args.enable_chunked_prefill,
args.max_num_batched_tokens,
args.gpu_memory_utilization, args.download_dir)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
@ -203,7 +202,6 @@ if __name__ == "__main__":
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=200,

View File

@ -391,7 +391,6 @@ async def benchmark(
input_requests: List[Tuple[str, int, int]],
logprobs: Optional[int],
best_of: int,
use_beam_search: bool,
request_rate: float,
disable_tqdm: bool,
profile: bool,
@ -419,7 +418,6 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
)
@ -441,7 +439,6 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
)
profile_output = await request_func(request_func_input=profile_input)
@ -464,7 +461,6 @@ async def benchmark(
output_len=output_len,
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
multi_modal_content=mm_content,
)
tasks.append(
@ -483,7 +479,6 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@ -679,7 +674,6 @@ def main(args: argparse.Namespace):
input_requests=input_requests,
logprobs=args.logprobs,
best_of=args.best_of,
use_beam_search=args.use_beam_search,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
@ -701,7 +695,6 @@ def main(args: argparse.Namespace):
result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of
result_json["use_beam_search"] = args.use_beam_search
result_json["num_prompts"] = args.num_prompts
# Metadata

View File

@ -73,7 +73,6 @@ def run_vllm(
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
@ -91,7 +90,6 @@ def run_vllm(
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
use_new_beam_search_impl: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@ -127,19 +125,19 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
if not use_new_beam_search_impl:
use_beam_search = False
if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
assert use_beam_search
prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
@ -165,7 +163,6 @@ async def run_vllm_async(
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
@ -224,9 +221,8 @@ async def run_vllm_async(
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
temperature=1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
@ -248,11 +244,9 @@ def run_hf(
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
@ -284,7 +278,7 @@ def run_hf(
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
@ -340,7 +334,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm":
run_args = [
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
@ -355,12 +349,11 @@ def main(args: argparse.Namespace):
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
else:
elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
elapsed_time = run_vllm(*run_args)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
args.hf_max_batch_size, args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
@ -414,8 +407,6 @@ if __name__ == "__main__":
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
@ -570,8 +561,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:

View File

@ -18,9 +18,6 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
]

View File

@ -43,15 +43,6 @@ def create_test_prompts(
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(n=3,
best_of=3,
use_beam_search=True,
temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
@ -60,15 +51,6 @@ def create_test_prompts(
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(n=3,
best_of=3,
use_beam_search=True,
temperature=0,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
]

View File

@ -23,11 +23,9 @@ MODELS = [
@pytest.fixture(scope="module", autouse=True)
def check_settings():
assert ENABLE_ARTIFICIAL_PREEMPT is True, (
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, "
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. "
"Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1."
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest "
"tests/basic_correctness/test_preemption.py`")
"pytest tests/basic_correctness/test_preemption.py`")
@pytest.fixture
@ -137,114 +135,6 @@ def test_preemption(
assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
caplog_vllm,
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
with vllm_runner(
model,
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")
assert ("is preempted by PreemptionMode.SWAP mode because there "
"is not enough KV cache space." in caplog_vllm.text)
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
preemption_metrics = None
for m in REGISTRY.collect():
if m.name == "vllm:num_preemptions":
preemption_metrics = m
assert preemption_metrics is not None
total_recorded_preemption = 0
for sample in preemption_metrics.samples:
total_recorded_preemption += sample.value
assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
@pytest.mark.parametrize("use_v2_block_manager", [True, False])
def test_swap_infeasible(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
use_v2_block_manager: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]
with vllm_runner(
model,
dtype=dtype,
swap_space=10,
block_size=BLOCK_SIZE,
# Since beam search have more than 1 sequence, prefill +
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
use_v2_block_manager=use_v2_block_manager,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens,
ignore_eos=True)
req_outputs = vllm_model.model.generate(
example_prompts,
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and not hang.
assert req_outputs[0].outputs[0].finish_reason == "length"
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])

View File

@ -782,7 +782,6 @@ class VllmRunner:
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs),
@ -795,19 +794,6 @@ class VllmRunner:
encoder_decoder_prompts, greedy_logprobs_params)
def generate_beam_search(
self,
prompts: List[str],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
beam_search_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens)
outputs = self.generate(prompts, beam_search_params)
return outputs
def generate_beam_search_new(
self,
prompts: Union[List[str], List[List[int]]],
beam_width: int,

View File

@ -85,73 +85,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Use a large block size to trigger more copy-on-writes.
"block_size": 32,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify beam search equality with block manager v1 and v2.
This requires copy-on-writes; if the v1 and v2 output is the same, then
we have some confidence cow is working.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
use_beam_search=True,
best_of=2,
)
print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{

View File

@ -13,7 +13,6 @@ def create_dummy_prompt(
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
prompt_tokens: Optional[List[int]] = None,
min_tokens: int = 0,
@ -37,7 +36,6 @@ def create_dummy_prompt(
seqs=[prompt],
arrival_time=time.time(),
sampling_params=SamplingParams(
use_beam_search=use_beam_search,
best_of=best_of,
max_tokens=max_tokens,
min_tokens=min_tokens),
@ -52,7 +50,6 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
use_beam_search: bool = False,
best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
if not block_size:
@ -85,9 +82,7 @@ def create_dummy_prompt_encoder_decoder(
from_decoder_prompt=False)
seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
sampling_params=SamplingParams(
use_beam_search=use_beam_search,
best_of=best_of),
sampling_params=SamplingParams(best_of=best_of),
arrival_time=time.time(),
lora_request=lora_request,
encoder_seq=encoder_prompt)

View File

@ -33,8 +33,8 @@ def test_beam_search_single_input(
max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search_new(
example_prompts, beam_width, max_tokens)
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i]

View File

@ -159,26 +159,6 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
assert first_sampler_output == second_sampler_output
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
_, fake_logits, sampler = _prepare_test(batch_size)
sampling_params = SamplingParams(
temperature=0,
best_of=2,
use_beam_search=True,
)
_do_sample(batch_size, fake_logits, sampler, sampling_params, device)
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# when handling an all-beam search case.
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
@ -479,7 +459,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_lens: List[int] = []
for i in range(batch_size):
expected: Optional[List[int]] = None
sampling_type = random.randint(0, 3)
sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
@ -498,10 +478,7 @@ def test_sampler_mixed(seed: int, device: str):
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
else:
sampling_params = SamplingParams(temperature=0,
use_beam_search=True,
best_of=2)
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
@ -530,9 +507,6 @@ def test_sampler_mixed(seed: int, device: str):
zip(sampler_output, seq_group_metadata_list)):
assert metadata.sampling_params is not None
if metadata.sampling_params.use_beam_search:
continue
if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None):
# Record seeded random result to compare with results of

View File

@ -1202,9 +1202,9 @@ class Scheduler:
seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
# TODO: does it work with parallel sampling?
no_beam_search = seq_group.sampling_params is None or (
seq_group.sampling_params.best_of == 1
and not seq_group.sampling_params.use_beam_search)
seq_group.sampling_params.best_of == 1)
return no_beam_search
def schedule(

View File

@ -33,7 +33,7 @@ from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
random_uuid, weak_bind)
get_beam_search_score, random_uuid, weak_bind)
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@ -1050,6 +1050,12 @@ class AsyncLLMEngine:
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = await self.get_tokenizer()
tokenizedPrompt = prompt if isinstance(
@ -1103,15 +1109,11 @@ class AsyncLLMEngine:
else:
new_beams.append(new_beam)
sorted_beams = sorted(new_beams,
key=lambda x: x.cum_logprob,
reverse=True)
sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
completed.extend(all_beams)
sorted_completed = sorted(completed,
key=lambda x: x.cum_logprob,
reverse=True)
sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@ -6,7 +6,6 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
@ -113,7 +112,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
if sampling_params.best_of == 1 and not sampling_params.use_beam_search:
if sampling_params.best_of == 1:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
@ -142,7 +141,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
@ -197,106 +195,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
lora_req=seq_group.lora_request,
)
# Non-beam search case
if not sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = sampling_params.best_of
length_penalty = sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
sampling_params.early_stopping, sampling_params,
best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in selected_child_seqs:
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
@ -305,61 +206,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
for seq, parent in selected_child_seqs:
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=current_worst_seq.eos_token_id)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id)
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id))
return current_worst_score >= highest_attainable_score
return

View File

@ -28,7 +28,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
is_list_of)
logger = init_logger(__name__)
@ -404,6 +405,12 @@ class LLM:
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
length_penalty = params.length_penalty
def sort_beams_key(x: BeamSearchSequence) -> float:
return get_beam_search_score(x.tokens, x.cum_logprob,
tokenizer.eos_token_id,
length_penalty)
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
@ -466,7 +473,7 @@ class LLM:
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
key=lambda x: x.cum_logprob,
key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
@ -474,7 +481,7 @@ class LLM:
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed,
key=lambda x: x.cum_logprob,
key=sort_beams_key,
reverse=True)
best_beams = sorted_completed[:beam_width]

View File

@ -184,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
@ -302,6 +301,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
@ -345,12 +345,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
@ -518,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel):
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
@ -597,6 +593,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
@ -641,13 +638,10 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,

View File

@ -63,7 +63,6 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@ -198,10 +197,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
# If set, allowing the use of deprecated beam search implementation
"VLLM_ALLOW_DEPRECATED_BEAM_SEARCH":
lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1",
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),

View File

@ -947,8 +947,6 @@ def get_logprobs(
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
largest_num_logprobs = -1
# If beam search is enabled.
use_beam_search = False
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
@ -981,8 +979,6 @@ def get_logprobs(
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
use_beam_search = use_beam_search or sampling_params.use_beam_search
assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0:
@ -995,7 +991,7 @@ def get_logprobs(
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation.
if largest_num_logprobs >= 0 or use_beam_search:
if largest_num_logprobs >= 0:
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids,
device=logprobs.device)
@ -1121,13 +1117,12 @@ def _get_sampled_logprob_if_needed(
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs
use_beam_search = seq_group.sampling_params.use_beam_search
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample:
assert len(next_token_ids) > 0
if num_logprobs is None and not use_beam_search:
if num_logprobs is None:
for next_token_id in next_token_ids:
# Use a dummy logprob
sampled_logprobs.append({next_token_id: Logprob(inf)})

View File

@ -142,11 +142,7 @@ class RequestOutput:
else:
# Get the top-n sequences.
n = sampling_params.n
if sampling_params.use_beam_search:
sorting_key = lambda seq: seq.get_beam_search_score(
sampling_params.length_penalty)
else:
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]

View File

@ -10,7 +10,6 @@ import torch
from pydantic import BaseModel
from typing_extensions import Annotated
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
@ -23,7 +22,6 @@ class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
RANDOM_SEED = 2
BEAM = 3
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
@ -134,16 +132,6 @@ class SamplingParams(
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
use_beam_search: Whether to use beam search instead of sampling.
length_penalty: Float that penalizes sequences based on their length.
Used in beam search.
early_stopping: Controls the stopping condition for beam search. It
accepts the following values: `True`, where the generation stops as
soon as there are `best_of` complete candidates; `False`, where an
heuristic is applied and the generation stops when is it very
unlikely to find better candidates; `"never"`, where the beam search
procedure only stops when there cannot be better candidates
(canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are
@ -193,9 +181,6 @@ class SamplingParams(
top_k: int = -1
min_p: float = 0.0
seed: Optional[int] = None
use_beam_search: bool = False
length_penalty: float = 1.0
early_stopping: Union[bool, str] = False
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
ignore_eos: bool = False
@ -238,9 +223,6 @@ class SamplingParams(
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
@ -280,9 +262,6 @@ class SamplingParams(
top_k=top_k,
min_p=min_p,
seed=seed,
use_beam_search=use_beam_search,
length_penalty=length_penalty,
early_stopping=early_stopping,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
@ -334,20 +313,13 @@ class SamplingParams(
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
self._verify_args()
if self.use_beam_search:
if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH:
raise ValueError(
"Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa
)
self._verify_beam_search()
else:
self._verify_non_beam_search()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self.top_p = 1.0
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
if self.temperature < _SAMPLING_EPS:
# Zero temperature means greedy sampling.
self.top_p = 1.0
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# eos_token_id is added to this by the engine
self._all_stop_token_ids = set(self.stop_token_ids)
@ -417,31 +389,6 @@ class SamplingParams(
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None:
if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam "
f"search. Got {self.best_of}.")
if self.temperature > _SAMPLING_EPS:
raise ValueError("temperature must be 0 when using beam search.")
if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int)
if self.best_of > 1:
@ -476,8 +423,6 @@ class SamplingParams(
@cached_property
def sampling_type(self) -> SamplingType:
if self.use_beam_search:
return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
if self.seed is not None:
@ -514,9 +459,6 @@ class SamplingParams(
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"seed={self.seed}, "
f"use_beam_search={self.use_beam_search}, "
f"length_penalty={self.length_penalty}, "
f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
@ -542,3 +484,4 @@ class BeamSearchParams(
max_tokens: int
ignore_eos: bool = False
temperature: float = 0.0
length_penalty: float = 1.0

View File

@ -577,25 +577,6 @@ class Sequence:
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
def get_beam_search_score(self,
length_penalty: float = 1.0,
seq_len: Optional[int] = None,
eos_token_id: Optional[int] = None) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
if seq_len is None:
seq_len = self.get_len()
# NOTE: HF implementation does not count the EOS token
# towards the length, we align with that here for testing.
if (eos_token_id is not None
and self.get_last_token_id() == eos_token_id):
seq_len -= 1
return self.get_cumulative_logprob() / (seq_len**length_penalty)
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
@ -809,25 +790,18 @@ class SequenceGroup:
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
if self.sampling_params and self.sampling_params.use_beam_search:
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
return best_of
else:
if self.sampling_params:
best_of = self.sampling_params.best_of
assert isinstance(best_of, int)
if best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
if best_of > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return best_of
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_unfinished_seqs()
def get_seqs(
self,

View File

@ -1361,3 +1361,22 @@ class AtomicCounter:
@property
def value(self):
return self._value
def get_beam_search_score(
tokens: List[int],
cumulative_logprob: float,
eos_token_id: int,
length_penalty: float = 1.0,
) -> float:
"""Calculate the beam search score with length penalty.
Adapted from
https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
"""
seq_len = len(tokens)
if tokens[-1] == eos_token_id:
seq_len -= 1
return cumulative_logprob / (seq_len**length_penalty)

View File

@ -453,9 +453,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
best_of.append(sampling_params.best_of)
if sampling_params.use_beam_search:
raise NotImplementedError(
"Beam search is not supported by the TPU backend.")
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")