[core] remove beam search from the core (#9105)

youkaichao 2024-10-06 22:47:04 -07:00 committed by GitHub
parent c8f26bb636
commit 18b296fdb2
25 changed files with 98 additions and 596 deletions


@@ -23,7 +23,6 @@ class RequestFuncInput:
output_len: int
model: str
best_of: int = 1
- use_beam_search: bool = False
logprobs: Optional[int] = None
multi_modal_content: Optional[dict] = None
ignore_eos: bool = False
@@ -49,7 +48,6 @@ async def async_request_tgi(
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
- assert not request_func_input.use_beam_search
params = {
"best_of": request_func_input.best_of,
"max_new_tokens": request_func_input.output_len,
@@ -121,7 +119,6 @@ async def async_request_trt_llm(
assert api_url.endswith("generate_stream")
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
- assert not request_func_input.use_beam_search
assert request_func_input.best_of == 1
payload = {
"accumulate_tokens": True,
@@ -187,7 +184,6 @@ async def async_request_deepspeed_mii(
) -> RequestFuncOutput:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
assert request_func_input.best_of == 1
- assert not request_func_input.use_beam_search
payload = {
"prompt": request_func_input.prompt,
@@ -235,7 +231,6 @@ async def async_request_openai_completions(
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
- assert not request_func_input.use_beam_search
payload = {
"model": request_func_input.model,
"prompt": request_func_input.prompt,
@@ -317,7 +312,6 @@ async def async_request_openai_chat_completions(
), "OpenAI Chat Completions API URL must end with 'chat/completions'."
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
- assert not request_func_input.use_beam_search
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)


@@ -51,9 +51,8 @@ def main(args: argparse.Namespace):
sampling_params = SamplingParams(
n=args.n,
- temperature=0.0 if args.use_beam_search else 1.0,
+ temperature=1.0,
top_p=1.0,
- use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=args.output_len,
)


@@ -68,7 +68,6 @@ def run_vllm(
tensor_parallel_size: int,
seed: int,
n: int,
- use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
@@ -114,9 +113,8 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
- temperature=0.0 if use_beam_search else 1.0,
+ temperature=1.0,
top_p=1.0,
- use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
@@ -144,15 +142,16 @@ def main(args: argparse.Namespace):
args.output_len)
if args.backend == "vllm":
- elapsed_time = run_vllm(
- requests, args.model, args.tokenizer, args.quantization,
- args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
- args.trust_remote_code, args.dtype, args.max_model_len,
- args.enforce_eager, args.kv_cache_dtype,
- args.quantization_param_path, args.device,
- args.enable_prefix_caching, args.enable_chunked_prefill,
- args.max_num_batched_tokens, args.gpu_memory_utilization,
- args.download_dir)
+ elapsed_time = run_vllm(requests, args.model, args.tokenizer,
+ args.quantization, args.tensor_parallel_size,
+ args.seed, args.n, args.trust_remote_code,
+ args.dtype, args.max_model_len,
+ args.enforce_eager, args.kv_cache_dtype,
+ args.quantization_param_path, args.device,
+ args.enable_prefix_caching,
+ args.enable_chunked_prefill,
+ args.max_num_batched_tokens,
+ args.gpu_memory_utilization, args.download_dir)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
@@ -203,7 +202,6 @@ if __name__ == "__main__":
type=int,
default=1,
help="Number of generated sequences per prompt.")
- parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=200,


@@ -391,7 +391,6 @@ async def benchmark(
input_requests: List[Tuple[str, int, int]],
logprobs: Optional[int],
best_of: int,
- use_beam_search: bool,
request_rate: float,
disable_tqdm: bool,
profile: bool,
@@ -419,7 +418,6 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
- use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
)
@@ -441,7 +439,6 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
- use_beam_search=use_beam_search,
multi_modal_content=test_mm_content,
)
profile_output = await request_func(request_func_input=profile_input)
@@ -464,7 +461,6 @@ async def benchmark(
output_len=output_len,
logprobs=logprobs,
best_of=best_of,
- use_beam_search=use_beam_search,
multi_modal_content=mm_content,
)
tasks.append(
@@ -483,7 +479,6 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
best_of=best_of,
- use_beam_search=use_beam_search,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@@ -679,7 +674,6 @@ def main(args: argparse.Namespace):
input_requests=input_requests,
logprobs=args.logprobs,
best_of=args.best_of,
- use_beam_search=args.use_beam_search,
request_rate=args.request_rate,
disable_tqdm=args.disable_tqdm,
profile=args.profile,
@@ -701,7 +695,6 @@ def main(args: argparse.Namespace):
result_json["model_id"] = model_id
result_json["tokenizer_id"] = tokenizer_id
result_json["best_of"] = args.best_of
- result_json["use_beam_search"] = args.use_beam_search
result_json["num_prompts"] = args.num_prompts
# Metadata


@@ -73,7 +73,6 @@ def run_vllm(
tensor_parallel_size: int,
seed: int,
n: int,
- use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
@@ -91,7 +90,6 @@ def run_vllm(
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
- use_new_beam_search_impl: bool = False,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@@ -127,19 +125,19 @@ def run_vllm(
sampling_params.append(
SamplingParams(
n=n,
- temperature=0.0 if use_beam_search else 1.0,
+ temperature=1.0,
top_p=1.0,
- use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
- if not use_new_beam_search_impl:
+ use_beam_search = False
+ if not use_beam_search:
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
- assert use_beam_search
prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
@@ -165,7 +163,6 @@ async def run_vllm_async(
tensor_parallel_size: int,
seed: int,
n: int,
- use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
@@ -224,9 +221,8 @@ async def run_vllm_async(
sampling_params.append(
SamplingParams(
n=n,
- temperature=0.0 if use_beam_search else 1.0,
+ temperature=1.0,
top_p=1.0,
- use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
@@ -248,11 +244,9 @@ def run_hf(
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
- use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
- assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
@@ -284,7 +278,7 @@ def run_hf(
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
- do_sample=not use_beam_search,
+ do_sample=True,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
@@ -340,7 +334,7 @@ def main(args: argparse.Namespace):
if args.backend == "vllm":
run_args = [
requests, args.model, args.tokenizer, args.quantization,
- args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
+ args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
@@ -355,12 +349,11 @@ def main(args: argparse.Namespace):
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
else:
- elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
+ elapsed_time = run_vllm(*run_args)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
- elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
- args.use_beam_search, args.hf_max_batch_size,
- args.trust_remote_code)
+ elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
+ args.hf_max_batch_size, args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
@@ -414,8 +407,6 @@ if __name__ == "__main__":
type=int,
default=1,
help="Number of generated sequences per prompt.")
- parser.add_argument("--use-beam-search", action="store_true")
- parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
@@ -570,8 +561,6 @@ if __name__ == "__main__":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
- if args.use_beam_search:
- raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:


@@ -18,9 +18,6 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]:
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
- ("It is only with the heart that one can see rightly",
- SamplingParams(n=3, best_of=3, use_beam_search=True,
- temperature=0.0)),
]


@@ -43,15 +43,6 @@ def create_test_prompts(
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
- (
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
- SamplingParams(n=3,
- best_of=3,
- use_beam_search=True,
- temperature=0,
- max_tokens=128,
- stop_token_ids=[32003]),
- LoRARequest("sql-lora", 1, lora_path)),
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
@@ -60,15 +51,6 @@ def create_test_prompts(
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)),
- (
- "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
- SamplingParams(n=3,
- best_of=3,
- use_beam_search=True,
- temperature=0,
- max_tokens=128,
- stop_token_ids=[32003]),
- LoRARequest("sql-lora", 1, lora_path)),
]


@@ -23,11 +23,9 @@ MODELS = [
@pytest.fixture(scope="module", autouse=True)
def check_settings():
assert ENABLE_ARTIFICIAL_PREEMPT is True, (
- "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, "
- "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. "
+ "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1."
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 "
- "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest "
- "tests/basic_correctness/test_preemption.py`")
+ "pytest tests/basic_correctness/test_preemption.py`")
@pytest.fixture
@@ -137,114 +135,6 @@ def test_preemption(
assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
def test_swap(
caplog_vllm,
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
) -> None:
"""Use beam search enables swapping."""
example_prompts = example_prompts[:1]
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
max_tokens)
with vllm_runner(
model,
dtype=dtype,
swap_space=10,
disable_log_stats=False,
worker_use_ray=worker_use_ray,
) as vllm_model:
vllm_outputs = vllm_model.generate_beam_search(example_prompts,
beam_width, max_tokens)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
total_preemption = (
vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption)
for i in range(len(example_prompts)):
hf_output_ids, _ = hf_outputs[i]
vllm_output_ids, _ = vllm_outputs[i]
assert len(hf_output_ids) == len(vllm_output_ids)
for j in range(len(hf_output_ids)):
assert hf_output_ids[j] == vllm_output_ids[j], (
f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
f"vLLM: {vllm_output_ids}")
assert ("is preempted by PreemptionMode.SWAP mode because there "
"is not enough KV cache space." in caplog_vllm.text)
# Ensure the count bucket of request-level histogram metrics matches
# the number of requests as a simple sanity check to ensure metrics are
# generated
preemption_metrics = None
for m in REGISTRY.collect():
if m.name == "vllm:num_preemptions":
preemption_metrics = m
assert preemption_metrics is not None
total_recorded_preemption = 0
for sample in preemption_metrics.samples:
total_recorded_preemption += sample.value
assert total_preemption == total_recorded_preemption
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])
@pytest.mark.parametrize("beam_width", [4])
@pytest.mark.parametrize("use_v2_block_manager", [True, False])
def test_swap_infeasible(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
beam_width: int,
worker_use_ray: bool,
use_v2_block_manager: bool,
) -> None:
"""Verify infeasible swap request will be ignored."""
BLOCK_SIZE = 16
prefill_blocks = 2
decode_blocks = max_tokens // BLOCK_SIZE
example_prompts = example_prompts[:1]
with vllm_runner(
model,
dtype=dtype,
swap_space=10,
block_size=BLOCK_SIZE,
# Since beam search have more than 1 sequence, prefill +
# decode blocks are not enough to finish.
num_gpu_blocks_override=prefill_blocks + decode_blocks,
max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE,
worker_use_ray=worker_use_ray,
use_v2_block_manager=use_v2_block_manager,
) as vllm_model:
sampling_params = SamplingParams(n=beam_width,
use_beam_search=True,
temperature=0.0,
max_tokens=max_tokens,
ignore_eos=True)
req_outputs = vllm_model.model.generate(
example_prompts,
sampling_params=sampling_params,
)
assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
< ARTIFICIAL_PREEMPTION_MAX_CNT)
# Verify the request is ignored and not hang.
assert req_outputs[0].outputs[0].finish_reason == "length"
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [96])


@@ -782,7 +782,6 @@ class VllmRunner:
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
- use_beam_search=False,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=(num_prompt_logprobs),
@@ -795,19 +794,6 @@ class VllmRunner:
encoder_decoder_prompts, greedy_logprobs_params)
def generate_beam_search(
- self,
- prompts: List[str],
- beam_width: int,
- max_tokens: int,
- ) -> List[Tuple[List[List[int]], List[str]]]:
- beam_search_params = SamplingParams(n=beam_width,
- use_beam_search=True,
- temperature=0.0,
- max_tokens=max_tokens)
- outputs = self.generate(prompts, beam_search_params)
- return outputs
- def generate_beam_search_new(
self,
prompts: Union[List[str], List[List[int]]],
beam_width: int,


@@ -85,73 +85,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Use a small model for a fast test.
"model": "facebook/opt-125m",
# skip cuda graph creation for fast test.
"enforce_eager": True,
# Use a large block size to trigger more copy-on-writes.
"block_size": 32,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
"use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
"use_v2_block_manager": True,
"preemption_mode": "swap"
}, {
"use_v2_block_manager": True,
"preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
test_llm_generator, batch_size):
"""Verify beam search equality with block manager v1 and v2.
This requires copy-on-writes; if the v1 and v2 output is the same, then
we have some confidence cow is working.
"""
output_len = 128
temperature = 0.0
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]
sampling_params = SamplingParams(
max_tokens=output_len,
ignore_eos=True,
temperature=temperature,
use_beam_search=True,
best_of=2,
)
print('Getting token ids from block manager v1')
baseline_token_ids = get_token_ids_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)
print('Getting token ids from block manager v2')
test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
prompts, sampling_params)
for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
test_token_ids):
assert expected_token_ids == actual_token_ids
assert baseline_token_ids == test_token_ids
@pytest.mark.parametrize(
"common_llm_kwargs",
[{


@@ -13,7 +13,6 @@ def create_dummy_prompt(
prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
- use_beam_search: bool = False,
best_of: int = 1,
prompt_tokens: Optional[List[int]] = None,
min_tokens: int = 0,
@@ -37,7 +36,6 @@ def create_dummy_prompt(
seqs=[prompt],
arrival_time=time.time(),
sampling_params=SamplingParams(
- use_beam_search=use_beam_search,
best_of=best_of,
max_tokens=max_tokens,
min_tokens=min_tokens),
@@ -52,7 +50,6 @@ def create_dummy_prompt_encoder_decoder(
encoder_prompt_length: int,
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
- use_beam_search: bool = False,
best_of: int = 1,
) -> Tuple[Sequence, Sequence, SequenceGroup]:
if not block_size:
@@ -85,9 +82,7 @@ def create_dummy_prompt_encoder_decoder(
from_decoder_prompt=False)
seq_group = SequenceGroup(request_id=request_id,
seqs=[decoder_prompt],
- sampling_params=SamplingParams(
- use_beam_search=use_beam_search,
- best_of=best_of),
+ sampling_params=SamplingParams(best_of=best_of),
arrival_time=time.time(),
lora_request=lora_request,
encoder_seq=encoder_prompt)


@@ -33,8 +33,8 @@ def test_beam_search_single_input(
max_tokens)
with vllm_runner(model, dtype=dtype) as vllm_model:
- vllm_outputs = vllm_model.generate_beam_search_new(
- example_prompts, beam_width, max_tokens)
+ vllm_outputs = vllm_model.generate_beam_search(example_prompts,
+ beam_width, max_tokens)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_texts = hf_outputs[i]


@@ -159,26 +159,6 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
assert first_sampler_output == second_sampler_output
- @pytest.mark.parametrize("seed", RANDOM_SEEDS)
- @pytest.mark.parametrize("device", CUDA_DEVICES)
- def test_sampler_all_beam(seed: int, device: str):
- set_random_seed(seed)
- torch.set_default_device(device)
- batch_size = random.randint(1, 256)
- _, fake_logits, sampler = _prepare_test(batch_size)
- sampling_params = SamplingParams(
- temperature=0,
- best_of=2,
- use_beam_search=True,
- )
- _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
- # no assertion here as I am not sure how to determine whether
- # the outputs are expected - in other words, this just tests
- # whether there are no exceptions in the sampler
- # when handling an all-beam search case.
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
@@ -479,7 +459,7 @@ def test_sampler_mixed(seed: int, device: str):
seq_lens: List[int] = []
for i in range(batch_size):
expected: Optional[List[int]] = None
- sampling_type = random.randint(0, 3)
+ sampling_type = random.randint(0, 2)
if sampling_type == 0:
sampling_params = SamplingParams(temperature=0)
expected = [int(torch.argmax(fake_logits[i], dim=-1).item())]
@@ -498,10 +478,7 @@ def test_sampler_mixed(seed: int, device: str):
for idx in range(n):
fake_logits[i, i + idx] = 1e2
expected = list(range(i, i + n))
- else:
- sampling_params = SamplingParams(temperature=0,
- use_beam_search=True,
- best_of=2)
expected_tokens.append(expected)
seq_group_metadata_list.append(
SequenceGroupMetadata(
@@ -530,9 +507,6 @@ def test_sampler_mixed(seed: int, device: str):
zip(sampler_output, seq_group_metadata_list)):
assert metadata.sampling_params is not None
- if metadata.sampling_params.use_beam_search:
- continue
if (metadata.sampling_params.seed is not None
and expected_tokens[i] is None):
# Record seeded random result to compare with results of


@@ -1202,9 +1202,9 @@ class Scheduler:
seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
+ # TODO: does it work with parallel sampling?
no_beam_search = seq_group.sampling_params is None or (
- seq_group.sampling_params.best_of == 1
- and not seq_group.sampling_params.use_beam_search)
+ seq_group.sampling_params.best_of == 1)
return no_beam_search
def schedule(


@@ -33,7 +33,7 @@ from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
- random_uuid, weak_bind)
+ get_beam_search_score, random_uuid, weak_bind)
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1050,6 +1050,12 @@ class AsyncLLMEngine:
max_tokens = params.max_tokens
ignore_eos = params.ignore_eos
temperature = params.temperature
+ length_penalty = params.length_penalty
+ def sort_beams_key(x: BeamSearchSequence) -> float:
+ return get_beam_search_score(x.tokens, x.cum_logprob,
+ tokenizer.eos_token_id,
+ length_penalty)
tokenizer = await self.get_tokenizer()
tokenizedPrompt = prompt if isinstance(
@@ -1103,15 +1109,11 @@ class AsyncLLMEngine:
else:
new_beams.append(new_beam)
- sorted_beams = sorted(new_beams,
- key=lambda x: x.cum_logprob,
- reverse=True)
+ sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True)
all_beams = sorted_beams[:beam_width]
completed.extend(all_beams)
- sorted_completed = sorted(completed,
- key=lambda x: x.cum_logprob,
- reverse=True)
+ sorted_completed = sorted(completed, key=sort_beams_key, reverse=True)
best_beams = sorted_completed[:beam_width]
for beam in best_beams:
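The two sorted() calls above now rank beams with the length-penalized score from vllm.utils instead of raw cumulative log-probability. A minimal self-contained sketch of that ranking rule, assuming made-up token ids and logprobs (the Beam dataclass below is only a stand-in for vLLM's BeamSearchSequence):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Beam:  # stand-in for vllm's BeamSearchSequence
    tokens: List[int]
    cum_logprob: float


def get_beam_search_score(tokens: List[int], cumulative_logprob: float,
                          eos_token_id: int,
                          length_penalty: float = 1.0) -> float:
    # Same rule as the helper this PR adds: a trailing EOS is not counted.
    seq_len = len(tokens)
    if tokens[-1] == eos_token_id:
        seq_len -= 1
    return cumulative_logprob / (seq_len**length_penalty)


EOS = 2
beams = [Beam([5, 6, 7, EOS], -3.0), Beam([5, 6], -2.5)]
sort_key = lambda b: get_beam_search_score(b.tokens, b.cum_logprob, EOS, 1.0)

# The longer finished beam wins here: -3.0 / 3 = -1.0 beats -2.5 / 2 = -1.25.
best = sorted(beams, key=sort_key, reverse=True)[0]
print(best.tokens)
```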


@@ -1,4 +1,4 @@
- from typing import Dict, List, Optional, Tuple, Union
+ from typing import Dict, List, Tuple
from vllm.config import SchedulerConfig
from vllm.core.scheduler import Scheduler
@@ -6,7 +6,6 @@ from vllm.engine.output_processor.interfaces import (
SequenceGroupOutputProcessor)
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.logger import init_logger
- from vllm.sampling_params import SamplingParams
from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput,
SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
@@ -113,7 +112,7 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
outputs: SequenceGroupOutput,
is_async: bool) -> None:
sampling_params = seq_group.sampling_params
- if sampling_params.best_of == 1 and not sampling_params.use_beam_search:
+ if sampling_params.best_of == 1:
# only have one output sample
sample = outputs.samples[0]
# only have one sequence
@@ -142,7 +141,6 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Process samples
samples = outputs.samples
parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
- existing_finished_seqs = seq_group.get_finished_seqs()
parent_child_dict: Dict[int, List[SequenceOutput]] = {
parent_seq.seq_id: []
for parent_seq in parent_seqs
@@ -197,106 +195,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
lora_req=seq_group.lora_request,
)
# Non-beam search case
if not sampling_params.use_beam_search:
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
for scheduler in self.scheduler:
scheduler.fork_seq(parent, seq)
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
# NOTE: we need to fork the new sequences before freeing the
# old sequences.
for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
return
# Beam search case
# Select the child sequences to keep in the sequence group.
selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = []
beam_width = sampling_params.best_of
length_penalty = sampling_params.length_penalty
# Select the newly finished sequences with the highest scores
# to replace existing finished sequences.
# Tuple of (seq, parent, is_new)
existing_finished_seqs = [(seq, None, False)
for seq in existing_finished_seqs]
new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
if seq.is_finished()]
all_finished_seqs = existing_finished_seqs + new_finished_seqs
# Sort the finished sequences by their scores.
all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
for seq, parent, is_new in all_finished_seqs[:beam_width]:
if is_new:
# A newly generated child sequence finishes and has a high
# score, so we will add it into the sequence group.
selected_child_seqs.append((seq, parent))
for seq, parent, is_new in all_finished_seqs[beam_width:]:
if is_new:
# A newly generated child sequence finishes but has a low
# score, so we will not add it into the sequence group.
# Additionally, if this sequence is a continuation of a
# parent sequence, we will need remove the parent sequence
# from the sequence group.
unselected_child_seqs.append((seq, parent))
else:
# An existing finished sequence has a low score, so we will
# remove it from the sequence group.
seq_group.remove(seq.seq_id)
# select the top beam_width sequences from the running
# sequences for the next iteration to continue the beam
# search.
running_child_seqs = [(seq, parent) for seq, parent in child_seqs
if not seq.is_finished()]
# Sort the running sequences by their scores.
running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
length_penalty=length_penalty, eos_token_id=x[0].eos_token_id),
reverse=True)
# Check if we can stop the beam search.
if len(running_child_seqs) == 0:
# No running sequences, stop the beam search.
stop_beam_search = True
elif len(all_finished_seqs) < beam_width:
# Not enough finished sequences, continue the beam search.
stop_beam_search = False
else:
# Check the early stopping criteria
best_running_seq = running_child_seqs[0][0]
current_worst_seq = all_finished_seqs[beam_width - 1][0]
stop_beam_search = self._check_beam_search_early_stopping(
sampling_params.early_stopping, sampling_params,
best_running_seq, current_worst_seq)
if stop_beam_search:
# Stop the beam search and remove all the running sequences from
# the sequence group.
unselected_child_seqs.extend(running_child_seqs)
else:
# Continue the beam search and select the top beam_width sequences
# to continue the beam search.
selected_child_seqs.extend(running_child_seqs[:beam_width])
# The remaining running sequences will not be used in the next
# iteration. Again, if these sequences are continuations of
# parent sequences, we will need to remove the parent sequences
# from the sequence group.
unselected_child_seqs.extend(running_child_seqs[beam_width:])
# For newly created child sequences, add them to the sequence group
# and fork them in block manager if they are not finished.
- for seq, parent in selected_child_seqs:
+ for seq, parent in child_seqs:
if seq is not parent:
seq_group.add(seq)
if not seq.is_finished():
@@ -305,61 +206,10 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
# Free the finished and selected parent sequences' memory in block
# manager. Keep them in the sequence group as candidate output.
- for seq, parent in selected_child_seqs:
+ # NOTE: we need to fork the new sequences before freeing the
+ # old sequences.
+ for seq, parent in child_seqs:
if seq is parent and seq.is_finished():
for scheduler in self.scheduler:
scheduler.free_seq(seq)
+ return
# Remove the unselected parent sequences from the sequence group and
# free their memory in block manager.
for seq, parent in unselected_child_seqs:
if seq is parent:
# Remove the parent sequence if it is not selected for next
# iteration
seq_group.remove(seq.seq_id)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
def _check_beam_search_early_stopping(
self,
early_stopping: Union[bool, str],
sampling_params: SamplingParams,
best_running_seq: Sequence,
current_worst_seq: Sequence,
) -> bool:
assert sampling_params.use_beam_search
length_penalty = sampling_params.length_penalty
if early_stopping is True:
return True
current_worst_score = current_worst_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=current_worst_seq.eos_token_id)
if early_stopping is False:
highest_attainable_score = best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id)
else:
assert early_stopping == "never"
if length_penalty > 0.0:
# If length_penalty > 0.0, beam search will prefer longer
# sequences. The highest attainable score calculation is
# based on the longest possible sequence length in this case.
max_possible_length = max(
best_running_seq.get_prompt_len() +
sampling_params.max_tokens,
self.scheduler_config.max_model_len)
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id,
seq_len=max_possible_length))
else:
# Otherwise, beam search will prefer shorter sequences. The
# highest attainable score calculation is based on the current
# sequence length.
highest_attainable_score = (
best_running_seq.get_beam_search_score(
length_penalty=length_penalty,
eos_token_id=best_running_seq.eos_token_id))
return current_worst_score >= highest_attainable_score


@@ -28,7 +28,8 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
- from vllm.utils import Counter, deprecate_kwargs, is_list_of
+ from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
+ is_list_of)
logger = init_logger(__name__)
@@ -404,6 +405,12 @@ class LLM:
max_tokens = params.max_tokens
temperature = params.temperature
ignore_eos = params.ignore_eos
+ length_penalty = params.length_penalty
+ def sort_beams_key(x: BeamSearchSequence) -> float:
+ return get_beam_search_score(x.tokens, x.cum_logprob,
+ tokenizer.eos_token_id,
+ length_penalty)
tokenizer = self.get_tokenizer()
# generate 2 * beam_width candidates at each step
@@ -466,7 +473,7 @@ class LLM:
else:
instance_new_beams.append(new_beam)
sorted_beams = sorted(instance_new_beams,
- key=lambda x: x.cum_logprob,
+ key=sort_beams_key,
reverse=True)
instance.beams = sorted_beams[:beam_width]
@@ -474,7 +481,7 @@ class LLM:
for instance in instances:
instance.completed.extend(instance.beams)
sorted_completed = sorted(instance.completed,
- key=lambda x: x.cum_logprob,
+ key=sort_beams_key,
reverse=True)
best_beams = sorted_completed[:beam_width]
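With use_beam_search removed from SamplingParams, beam search is reached through the dedicated LLM entry point this hunk touches. A hedged usage sketch: the beam_search method name, the beam_width field, and the .sequences attribute on the result are assumed from the surrounding vLLM code rather than shown in this diff, and length_penalty is the field added here.

```python
from vllm import LLM
from vllm.sampling_params import BeamSearchParams

# Model name is arbitrary; beam_width and .sequences are assumptions.
llm = LLM(model="facebook/opt-125m")
params = BeamSearchParams(beam_width=4, max_tokens=32, length_penalty=1.0)
outputs = llm.beam_search(["The capital of France is"], params)

# Beams come back ranked by the length-penalized score (sort_beams_key above)
# rather than by raw cumulative logprob.
for beam in outputs[0].sequences:
    print(beam.tokens, beam.cum_logprob)
```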


@@ -184,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
- early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
@@ -302,6 +301,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
+ length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
@@ -345,12 +345,9 @@ class ChatCompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
min_tokens=self.min_tokens,
- use_beam_search=self.use_beam_search,
- early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
- length_penalty=self.length_penalty,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,
@@ -518,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel):
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
- early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
include_stop_str_in_output: bool = False
ignore_eos: bool = False
@@ -597,6 +593,7 @@ class CompletionRequest(OpenAIBaseModel):
max_tokens=max_tokens,
ignore_eos=self.ignore_eos,
temperature=temperature,
+ length_penalty=self.length_penalty,
)
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
@@ -641,13 +638,10 @@ class CompletionRequest(OpenAIBaseModel):
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
- use_beam_search=self.use_beam_search,
- early_stopping=self.early_stopping,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
- length_penalty=self.length_penalty,
truncate_prompt_tokens=self.truncate_prompt_tokens,
output_kind=RequestOutputKind.DELTA if self.stream \
else RequestOutputKind.FINAL_ONLY,


@@ -63,7 +63,6 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
- VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
@@ -198,10 +197,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
- # If set, allowing the use of deprecated beam search implementation
- "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH":
- lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1",
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),


@@ -947,8 +947,6 @@ def get_logprobs(
# largest num logprobs in this API. If every logprobs is None, it will be
# set to -1.
largest_num_logprobs = -1
- # If beam search is enabled.
- use_beam_search = False
# Select indices to compute logprob from, ranks of token ids, and the top
# k token ids from logprobs.
@@ -981,8 +979,6 @@ def get_logprobs(
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.logprobs)
- use_beam_search = use_beam_search or sampling_params.use_beam_search
assert len(next_token_ids) == len(query_indices)
if len(query_indices) == 0:
@@ -995,7 +991,7 @@ def get_logprobs(
# If largest_num_logprobs == -1, i.e. no logprobs are requested, we can
# skip the whole logprob calculation.
- if largest_num_logprobs >= 0 or use_beam_search:
+ if largest_num_logprobs >= 0:
query_indices_gpu = torch.tensor(query_indices, device=logprobs.device)
next_token_ids_gpu = torch.tensor(next_token_ids,
device=logprobs.device)
@@ -1121,13 +1117,12 @@ def _get_sampled_logprob_if_needed(
"""Compute the sample logprob if needed."""
seq_ids = seq_group.seq_ids
num_logprobs = seq_group.sampling_params.logprobs
- use_beam_search = seq_group.sampling_params.use_beam_search
sampled_logprobs: SampleLogprobs = []
next_token_ids, parent_seq_ids = sample_result
if seq_group.do_sample:
assert len(next_token_ids) > 0
- if num_logprobs is None and not use_beam_search:
+ if num_logprobs is None:
for next_token_id in next_token_ids:
# Use a dummy logprob
sampled_logprobs.append({next_token_id: Logprob(inf)})


@@ -142,11 +142,7 @@ class RequestOutput:
else:
# Get the top-n sequences.
n = sampling_params.n
- if sampling_params.use_beam_search:
- sorting_key = lambda seq: seq.get_beam_search_score(
- sampling_params.length_penalty)
- else:
- sorting_key = lambda seq: seq.get_cumulative_logprob()
+ sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]


@@ -10,7 +10,6 @@ import torch
from pydantic import BaseModel
from typing_extensions import Annotated
- import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -23,7 +22,6 @@ class SamplingType(IntEnum):
GREEDY = 0
RANDOM = 1
RANDOM_SEED = 2
- BEAM = 3
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
@@ -134,16 +132,6 @@ class SamplingParams(
considered, relative to the probability of the most likely token.
Must be in [0, 1]. Set to 0 to disable this.
seed: Random seed to use for the generation.
- use_beam_search: Whether to use beam search instead of sampling.
- length_penalty: Float that penalizes sequences based on their length.
- Used in beam search.
- early_stopping: Controls the stopping condition for beam search. It
- accepts the following values: `True`, where the generation stops as
- soon as there are `best_of` complete candidates; `False`, where an
- heuristic is applied and the generation stops when is it very
- unlikely to find better candidates; `"never"`, where the beam search
- procedure only stops when there cannot be better candidates
- (canonical beam search algorithm).
stop: List of strings that stop the generation when they are generated.
The returned output will not contain the stop strings.
stop_token_ids: List of tokens that stop the generation when they are
@@ -193,9 +181,6 @@ class SamplingParams(
top_k: int = -1
min_p: float = 0.0
seed: Optional[int] = None
- use_beam_search: bool = False
- length_penalty: float = 1.0
- early_stopping: Union[bool, str] = False
stop: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
ignore_eos: bool = False
@@ -238,9 +223,6 @@ class SamplingParams(
top_k: int = -1,
min_p: float = 0.0,
seed: Optional[int] = None,
- use_beam_search: bool = False,
- length_penalty: float = 1.0,
- early_stopping: Union[bool, str] = False,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
@@ -280,9 +262,6 @@ class SamplingParams(
top_k=top_k,
min_p=min_p,
seed=seed,
- use_beam_search=use_beam_search,
- length_penalty=length_penalty,
- early_stopping=early_stopping,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_stop_str_in_output,
@@ -334,20 +313,13 @@ class SamplingParams(
self.output_text_buffer_length = max(len(s) for s in self.stop) - 1
self._verify_args()
- if self.use_beam_search:
- if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH:
- raise ValueError(
- "Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa
- )
- self._verify_beam_search()
- else:
- self._verify_non_beam_search()
- if self.temperature < _SAMPLING_EPS:
- # Zero temperature means greedy sampling.
- self.top_p = 1.0
- self.top_k = -1
- self.min_p = 0.0
- self._verify_greedy_sampling()
+ if self.temperature < _SAMPLING_EPS:
+ # Zero temperature means greedy sampling.
+ self.top_p = 1.0
+ self.top_k = -1
+ self.min_p = 0.0
+ self._verify_greedy_sampling()
# eos_token_id is added to this by the engine
self._all_stop_token_ids = set(self.stop_token_ids)
@@ -417,31 +389,6 @@ class SamplingParams(
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")
def _verify_beam_search(self) -> None:
if self.best_of == 1:
raise ValueError("best_of must be greater than 1 when using beam "
f"search. Got {self.best_of}.")
if self.temperature > _SAMPLING_EPS:
raise ValueError("temperature must be 0 when using beam search.")
if self.top_p < 1.0 - _SAMPLING_EPS:
raise ValueError("top_p must be 1 when using beam search.")
if self.top_k != -1:
raise ValueError("top_k must be -1 when using beam search.")
if self.early_stopping not in [True, False, "never"]:
raise ValueError(
f"early_stopping must be True, False, or 'never', "
f"got {self.early_stopping}.")
def _verify_non_beam_search(self) -> None:
if self.early_stopping is not False:
raise ValueError("early_stopping is not effective and must be "
"False when not using beam search.")
if (self.length_penalty < 1.0 - _SAMPLING_EPS
or self.length_penalty > 1.0 + _SAMPLING_EPS):
raise ValueError(
"length_penalty is not effective and must be the "
"default value of 1.0 when not using beam search.")
def _verify_greedy_sampling(self) -> None:
assert isinstance(self.best_of, int)
if self.best_of > 1:
@@ -476,8 +423,6 @@ class SamplingParams(
@cached_property
def sampling_type(self) -> SamplingType:
- if self.use_beam_search:
- return SamplingType.BEAM
if self.temperature < _SAMPLING_EPS:
return SamplingType.GREEDY
if self.seed is not None:
@@ -514,9 +459,6 @@ class SamplingParams(
f"top_k={self.top_k}, "
f"min_p={self.min_p}, "
f"seed={self.seed}, "
- f"use_beam_search={self.use_beam_search}, "
- f"length_penalty={self.length_penalty}, "
- f"early_stopping={self.early_stopping}, "
f"stop={self.stop}, "
f"stop_token_ids={self.stop_token_ids}, "
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
@@ -542,3 +484,4 @@ class BeamSearchParams(
max_tokens: int
ignore_eos: bool = False
temperature: float = 0.0
+ length_penalty: float = 1.0


@@ -577,25 +577,6 @@ class Sequence:
def get_cumulative_logprob(self) -> float:
return self.data.cumulative_logprob
- def get_beam_search_score(self,
- length_penalty: float = 1.0,
- seq_len: Optional[int] = None,
- eos_token_id: Optional[int] = None) -> float:
- """Calculate the beam search score with length penalty.
- Adapted from
- https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
- """
- if seq_len is None:
- seq_len = self.get_len()
- # NOTE: HF implementation does not count the EOS token
- # towards the length, we align with that here for testing.
- if (eos_token_id is not None
- and self.get_last_token_id() == eos_token_id):
- seq_len -= 1
- return self.get_cumulative_logprob() / (seq_len**length_penalty)
def is_finished(self) -> bool:
return SequenceStatus.is_finished(self.status)
@@ -809,25 +790,18 @@ class SequenceGroup:
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
- if self.sampling_params and self.sampling_params.use_beam_search:
- # For beam search, maximally there will always be `best_of` beam
- # candidates running in the future.
- best_of = self.sampling_params.best_of
- assert isinstance(best_of, int)
- return best_of
- else:
- if self.sampling_params:
- best_of = self.sampling_params.best_of
- assert isinstance(best_of, int)
- if best_of > self.num_seqs():
- # At prompt stage, the sequence group is not yet filled up
- # and only have one sequence running. However, in the
- # generation stage, we will have `best_of` sequences
- # running.
- return best_of
- # At sampling stages, return the number of actual sequences
- # that are not finished yet.
- return self.num_unfinished_seqs()
+ if self.sampling_params:
+ best_of = self.sampling_params.best_of
+ assert isinstance(best_of, int)
+ if best_of > self.num_seqs():
+ # At prompt stage, the sequence group is not yet filled up
+ # and only have one sequence running. However, in the
+ # generation stage, we will have `best_of` sequences
+ # running.
+ return best_of
+ # At sampling stages, return the number of actual sequences
+ # that are not finished yet.
+ return self.num_unfinished_seqs()
def get_seqs(
self,


@@ -1361,3 +1361,22 @@ class AtomicCounter:
@property
def value(self):
return self._value
+ def get_beam_search_score(
+ tokens: List[int],
+ cumulative_logprob: float,
+ eos_token_id: int,
+ length_penalty: float = 1.0,
+ ) -> float:
+ """Calculate the beam search score with length penalty.
+ Adapted from
+ https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938
+ """
+ seq_len = len(tokens)
+ if tokens[-1] == eos_token_id:
+ seq_len -= 1
+ return cumulative_logprob / (seq_len**length_penalty)
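A quick sanity check of the helper added above, with made-up token ids: a trailing EOS token is excluded from the effective length, matching the HF reference cited in the docstring, and length_penalty exponentiates the length divisor.

```python
from vllm.utils import get_beam_search_score

EOS = 2  # illustrative eos_token_id

# A trailing EOS is not counted toward the length, so both calls score the
# sequence over 3 tokens: -3.0 / 3 == -1.0.
assert get_beam_search_score([11, 12, 13, EOS], -3.0, EOS) == -1.0
assert get_beam_search_score([11, 12, 13], -3.0, EOS) == -1.0

# length_penalty > 1.0 divides by seq_len ** penalty, favouring longer beams
# when log-probabilities are negative: -3.0 / 4**2 == -0.1875.
print(get_beam_search_score([11, 12, 13, 14], -3.0, EOS, length_penalty=2.0))
```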


@@ -453,9 +453,6 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
"backend.")
best_of.append(sampling_params.best_of)
- if sampling_params.use_beam_search:
- raise NotImplementedError(
- "Beam search is not supported by the TPU backend.")
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")