[V1][Core] Support for Structured Outputs (#12388)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

Parent: 1e3598edeb
Commit: 80e9afb5bc
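Not part of the diff: a minimal end-to-end sketch of the feature this commit adds, mirroring the new V1 tests further below (the model name, schema, and prompt are illustrative):

import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine, as the new tests do

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Illustrative schema; any xgrammar-supported JSON schema works here.
schema = {"type": "object", "properties": {"name": {"type": "string"}}}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
params = SamplingParams(
    temperature=1.0,
    max_tokens=200,
    guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"),
)
outputs = llm.generate(["Give a JSON object with a name field."], params)
print(outputs[0].outputs[0].text)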
@@ -204,6 +204,7 @@ steps:
     - VLLM_USE_V1=1 pytest -v -s v1/engine
     - VLLM_USE_V1=1 pytest -v -s v1/sample
     - VLLM_USE_V1=1 pytest -v -s v1/worker
+    - VLLM_USE_V1=1 pytest -v -s v1/structured_output
     - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
     - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
     # TODO: accuracy does not match, whether setting
.gitignore (vendored): 2 lines changed
@@ -197,7 +197,7 @@ _build/
 hip_compat.h

 # Benchmark dataset
-benchmarks/*.json
+benchmarks/**/*.json

 # Linting
 actionlint
@@ -1,507 +0,0 @@
(entire file deleted; the removed content follows)

# SPDX-License-Identifier: Apache-2.0
"""Benchmark guided decoding throughput."""
import argparse
import dataclasses
import json
import os
import random
import time

import datasets
import pandas as pd
import uvloop
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators


@dataclasses.dataclass
class SampleRequest:
    """A class representing a single inference request for benchmarking.

    Attributes:
        prompt: The input text prompt for the model.
        multi_modal_data: Optional dictionary containing multi-modal data (e.g.
            images).
        prompt_len: The length of the prompt in tokens.
        expected_output_len: The expected length of the output in tokens.
    """
    prompt: str
    prompt_len: int
    expected_output_len: int
    schema: dict
    structure_type: str = 'json'
    completion: str = None


def run_vllm(requests: list[SampleRequest],
             engine_args: EngineArgs,
             n: int,
             guided_decoding_rate: float = 1.0,
             warmup: bool = False) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(**vars(engine_args))
    assert all(
        llm.llm_engine.model_config.max_model_len >= (
            request.prompt_len + request.expected_output_len)
        for request in requests), (
            "Please ensure that max_model_len is greater than the sum of"
            " prompt_len and expected_output_len for all requests.")

    # Add the requests to the engine.
    prompts: list[str] = []
    sampling_params: list[SamplingParams] = []
    # create a list containing random selected true or false
    guided_decoding_req_idx = random.sample(
        range(len(requests)), int(len(requests) * guided_decoding_rate))

    if warmup:
        print(">>>>> Running warmup prompt, for the first 5")
        # We setup the first 5 requests to warmup FSM
        # if using xgrammar dataset, we will skip warmup
        warmup_requests = requests[:5]
        for i, request in enumerate(warmup_requests):
            prompts.append(request.prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    guided_decoding=GuidedDecodingParams(json=request.schema)
                    if guided_decoding_rate > 0 else None,
                ))
        llm.generate(prompts, sampling_params, use_tqdm=False)

    print(">>>>> Benchmark started...")
    prompts = []
    sampling_params = []
    for i, request in enumerate(requests):
        prompts.append(request.prompt)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=1.0,
                top_p=1.0,
                ignore_eos=True,
                max_tokens=request.expected_output_len,
                guided_decoding=GuidedDecodingParams(
                    **{request.structure_type: request.schema})
                if i in guided_decoding_req_idx else None,
            ))

    start = time.perf_counter()
    outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
    ret = []
    for output, request in zip(outputs, requests):
        generated_text = output.outputs[0].text
        ret.append({
            "generated": generated_text,
            "expected": request.completion
        })
    end = time.perf_counter()
    return end - start, ret


async def run_vllm_async(
        requests: list[SampleRequest],
        engine_args: AsyncEngineArgs,
        n: int,
        guided_decoding_rate: float = 1.0,
        warmup: bool = False,
        disable_frontend_multiprocessing: bool = False) -> float:
    from vllm import SamplingParams

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:

        assert all(
            llm.model_config.max_model_len >= (request.prompt_len +
                                               request.expected_output_len)
            for request in requests), (
                "Please ensure that max_model_len is greater than the sum of"
                " prompt_len and expected_output_len for all requests.")

        # Add the requests to the engine.
        prompts: list[str] = []
        sampling_params: list[SamplingParams] = []
        guided_decoding_req_idx = random.sample(
            range(len(requests)), int(len(requests) * guided_decoding_rate))

        if warmup:
            print(">>>>>> Running warmup prompt, for the first 5")
            # We setup the first 5 requests to warmup FSM
            # if using xgrammar dataset, we will skip warmup
            warmup_requests = requests[:5]
            for i, request in enumerate(warmup_requests):
                prompts.append(request.prompt)
                sampling_params.append(
                    SamplingParams(
                        n=n,
                        temperature=1.0,
                        top_p=1.0,
                        ignore_eos=True,
                        max_tokens=request.expected_output_len,
                        guided_decoding=GuidedDecodingParams(
                            json=request.schema)
                        if guided_decoding_rate > 0 else None,
                    ))
            generators = []
            for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
                generator = llm.generate(prompt, sp, request_id=f"test{i}")
                generators.append(generator)
            all_gens = merge_async_iterators(*generators)
            async for i, res in all_gens:
                pass

        print(">>>>> Benchmark started...")
        prompts = []
        sampling_params = []
        for i, request in enumerate(requests):
            prompts.append(request.prompt)
            sampling_params.append(
                SamplingParams(
                    n=n,
                    temperature=1.0,
                    top_p=1.0,
                    ignore_eos=True,
                    max_tokens=request.expected_output_len,
                    guided_decoding=GuidedDecodingParams(json=request.schema)
                    if i in guided_decoding_req_idx else None,
                ))

        generators = []
        start_time = []
        latencies = []
        start = time.perf_counter()
        for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
            generator = llm.generate(prompt, sp, request_id=f"test{i}")
            generators.append(generator)
            start_time.append(time.perf_counter())
            latencies.append([])
        all_gens = merge_async_iterators(*generators)
        generated_texts = [''] * len(requests)
        async for i, res in all_gens:
            generated_texts[i] = res.outputs[0].text
            lat = time.perf_counter() - start_time[i]
            latencies[i].append(lat)
        ret = [{
            'generated': gt,
            'expected': req.completion
        } for gt, req in zip(generated_texts, requests)]
        end = time.perf_counter()
        first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
        next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
                                  for lat in latencies])
        return end - start, ret, (first_latency, next_latency)


def sample_requests(tokenizer: PreTrainedTokenizerBase,
                    args: argparse.Namespace) -> list[SampleRequest]:
    if args.dataset == 'json':
        if args.json_schema_path is None:
            dir_path = os.path.dirname(os.path.realpath(__file__))
            args.json_schema_path = os.path.join(dir_path,
                                                 "structured_schemas",
                                                 "structured_schema_1.json")
        with open(args.json_schema_path) as f:
            schema = json.load(f)
        prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}"  # noqa: E501
        input_len = len(tokenizer(prompt).input_ids)
        print(f"Input length of the prompt: {input_len} tokens")
        requests = [
            SampleRequest(prompt=prompt,
                          prompt_len=input_len,
                          expected_output_len=args.output_len,
                          schema=schema,
                          structure_type=args.structure_type)
            for _ in range(args.num_prompts)
        ]

    elif args.dataset == "grammar":
        schema = """
            ?start: select_statement

            ?select_statement: "SELECT " column_list " FROM " table_name

            ?column_list: column_name ("," column_name)*

            ?table_name: identifier

            ?column_name: identifier

            ?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
        """
        prompt = "Generate an SQL query to show the 'username' \
            and 'email' from the 'users' table."

        input_len = len(tokenizer(prompt).input_ids)
        print(f"Input length of the prompt: {input_len} tokens")
        requests = [
            SampleRequest(prompt=prompt,
                          prompt_len=input_len,
                          expected_output_len=args.output_len,
                          schema=schema,
                          structure_type=args.structure_type)
            for _ in range(args.num_prompts)
        ]

    elif args.dataset == "regex":
        regex = r"\w+@\w+\.com\n"
        args.regex = regex
        prompt = "Generate an email address for Alan Turing, \
            who works in Enigma. End in .com and new line. \
            Example result: alan.turing@enigma.com\n"

        input_len = len(tokenizer(prompt).input_ids)
        print(f"Input length of the prompt: {input_len} tokens")
        requests = [
            SampleRequest(prompt=prompt,
                          prompt_len=input_len,
                          expected_output_len=args.output_len,
                          schema=regex,
                          structure_type=args.structure_type)
            for _ in range(args.num_prompts)
        ]

    elif args.dataset == "choice":
        choice = ["Positive", "Negative"]
        args.choice = choice
        prompt = "Classify this sentiment: vLLM is wonderful!"
        input_len = len(tokenizer(prompt).input_ids)
        print(f"Input length of the prompt: {input_len} tokens")
        requests = [
            SampleRequest(prompt=prompt,
                          prompt_len=input_len,
                          expected_output_len=args.output_len,
                          schema=choice,
                          structure_type=args.structure_type)
            for _ in range(args.num_prompts)
        ]

    elif args.dataset == "xgrammar_bench":
        args.warmup = False
        requests: list[SampleRequest] = []
        dataset = datasets.load_dataset("NousResearch/json-mode-eval",
                                        split="train")
        print(f"dataset has {len(dataset)} entries")
        len_dataset = len(dataset)
        for data_point_idx in range(args.num_prompts):
            idx = data_point_idx
            while idx >= len_dataset:
                idx -= len_dataset
            schema = dataset["schema"][idx]
            prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
                                                   tokenize=False)
            input_len = len(tokenizer(prompt).input_ids)
            completion = dataset["completion"][idx]

            requests.append(
                SampleRequest(prompt=prompt,
                              prompt_len=input_len,
                              expected_output_len=args.output_len,
                              schema=schema,
                              completion=completion))

    return requests


def evaluate(ret, args):

    def _eval_correctness_json(expected, actual):
        # extract json string from string using regex
        import re
        actual = actual.replace('\n', '').replace(' ', '').strip()
        try:
            actual = re.search(r'\{.*\}', actual).group()
            actual = json.loads(actual)
        except Exception:
            return False

        return True

    def _eval_correctness_choice(expected, actual):
        return actual in args.choice

    def _eval_correctness_regex(expected, actual):
        import re
        return re.match(args.regex, actual) is not None

    def _eval_correctness(expected, actual):
        if args.structure_type == 'json':
            return _eval_correctness_json(expected, actual)
        elif args.structure_type == 'regex':
            return _eval_correctness_regex(expected, actual)
        elif args.structure_type == 'choice':
            return _eval_correctness_choice(expected, actual)
        else:
            return None

    scores = []
    for res in ret:
        score = _eval_correctness(res['expected'], res['generated'])
        res['correctness'] = score
        scores.append(score)

    not_none_scores = [score for score in scores if score is not None]

    return (sum(not_none_scores) / len(not_none_scores) *
            100) if len(not_none_scores) > 0 else None


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # async engine is working for 'regex', 'choice' and 'grammar'
    if args.dataset == 'grammar':
        args.structure_type = 'grammar'
        args.async_engine = False
    elif args.dataset == 'regex':
        args.structure_type = 'regex'
        args.async_engine = False
    elif args.dataset == 'choice':
        args.structure_type = 'choice'
        args.async_engine = False
    else:
        args.structure_type = 'json'

    if args.no_guided_decoding:
        args.guided_decoding_ratio = 0
    if args.save_results:
        result_file_name = f'{args.guided_decoding_ratio}guided'
        result_file_name += f"_{args.model.split('/')[-1]}"
        result_file_name += f"_{args.dataset}"
        result_file_name += f"_{args.num_prompts}"
        result_file_name += f"_out{args.output_len}"
        result_file_name += f"_async{args.async_engine}"
        result_file_name += f"_warmup{args.warmup}"
        result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}"
        result_file_name += ".txt"
    else:
        result_file_name = None

    # Synthesize a prompt with the given input length.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    requests = sample_requests(tokenizer, args)

    if args.async_engine:
        engine_args = AsyncEngineArgs.from_cli_args(args)
        elapsed_time, ret, (first_latency, next_latency) = uvloop.run(
            run_vllm_async(requests, engine_args, args.n,
                           args.guided_decoding_ratio, args.warmup,
                           args.disable_frontend_multiprocessing))
    else:
        engine_args = EngineArgs.from_cli_args(args)
        elapsed_time, ret = run_vllm(requests, engine_args, args.n,
                                     args.guided_decoding_ratio, args.warmup)
        first_latency, next_latency = None, None

    score = evaluate(ret, args)
    total_num_tokens = sum(request.prompt_len + request.expected_output_len
                           for request in requests)
    total_output_tokens = sum(request.expected_output_len
                              for request in requests)
    if first_latency is not None:
        latency_breakdown = "\nFirst token latency(msecs):\n"
        latency_breakdown += f"{first_latency.describe()}"
        latency_breakdown += "\nNext token latency(msecs):\n"
        latency_breakdown += f"{next_latency.describe()}"
    print(
        f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
        f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
        f"{total_output_tokens / elapsed_time:.2f} output tokens/s",
        f"Correct rate is {score} %",
        f"{latency_breakdown if first_latency is not None else ''}")

    # Output JSON results if specified
    if args.output_json or result_file_name:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "total_output_tokens": total_output_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}",
            "output_tokens_per_second":
            f"{total_output_tokens / elapsed_time:.2f}",
            "correct_rate(%)": score
        }
        results = {"outputs": ret, **results}
        if first_latency is not None:
            results["first_token_latency(msecs)"] = first_latency.describe(
            ).to_dict()
            results["next_token_latency(msecs)"] = next_latency.describe(
            ).to_dict()
        if args.output_json:
            with open(args.output_json, "w") as f:
                json.dump(results, f, indent=4)
        elif result_file_name:
            with open(result_file_name, "w") as f:
                json.dump(results, f, indent=4)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument("--output-len",
                        type=int,
                        default=512,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument(
        "--dataset",
        default='json',
        choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
    parser.add_argument("--json_schema_path",
                        type=str,
                        default=None,
                        help="Path to json schema.")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=10,
                        help="Number of prompts to process.")
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    parser.add_argument("--async-engine",
                        action='store_true',
                        default=False,
                        help="Use vLLM async engine rather than LLM class.")
    parser.add_argument("--no-guided-decoding",
                        action='store_true',
                        default=False,
                        help="Whether to disable JSON decoding or not.")
    parser.add_argument("--guided-decoding-ratio",
                        type=float,
                        default=1.0,
                        help="Ratio of Guided Decoding requests")
    parser.add_argument("--disable-frontend-multiprocessing",
                        action='store_true',
                        default=False,
                        help="Disable decoupled async engine frontend.")
    parser.add_argument("--warmup",
                        action="store_true",
                        default=False,
                        help="Run warmup prompts before benchmark.")
    parser.add_argument("--save-results",
                        action="store_true",
                        default=False,
                        help="save output results.")
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    main(args)
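The deleted script above measured offline guided-decoding throughput. A rough sketch (not from this commit) of taking the same measurement through the LLM API that the new V1 tests exercise; the model name, schema, prompt, and request count are placeholders:

import time

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
prompts = ["Give a JSON object with a name field."] * 32  # arbitrary batch
params = SamplingParams(temperature=1.0,
                        ignore_eos=True,
                        max_tokens=256,
                        guided_decoding=GuidedDecodingParams(json=schema))

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start
# Count generated tokens to report throughput, like the removed benchmark did.
total_output = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{len(prompts) / elapsed:.2f} req/s, "
      f"{total_output / elapsed:.2f} output tok/s")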
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-r"""Benchmark online serving throughput with guided decoding.
+r"""Benchmark online serving throughput with structured outputs.

 On the server side, run one of the following commands:
     (vLLM OpenAI API server)
@@ -9,12 +9,12 @@ On the server side, run one of the following commands:
     ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>

 On the client side, run:
-    python benchmarks/benchmark_serving_guided.py \
+    python benchmarks/benchmark_serving_structured_output.py \
         --backend <backend> \
         --model <your_model> \
         --dataset json \
-        --guided-decoding-ratio 1.0 \
-        --guided-decoding-backend xgrammar \
+        --structured-output-ratio 1.0 \
+        --structured-output-backend xgrammar \
         --request-rate 10 \
         --num-prompts 1000

@@ -52,6 +52,9 @@ try:
 except ImportError:
     from argparse import ArgumentParser as FlexibleArgumentParser

+from vllm.v1.structured_output.utils import (
+    has_xgrammar_unsupported_json_features)
+
 MILLISECONDS_TO_SECONDS_CONVERSION = 1000

@@ -191,7 +194,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
         requests: list[SampleRequest] = []
         dataset = datasets.load_dataset("NousResearch/json-mode-eval",
                                         split="train")
-        print(f"dataset has {len(dataset)} entries")
+        full_dataset_len = len(dataset)
+
+        def _filter_func(item):
+            import json
+            schema = json.loads(item["schema"])
+            return not has_xgrammar_unsupported_json_features(schema)
+
+        dataset = dataset.filter(_filter_func)
+        num_filtered_out = full_dataset_len - len(dataset)
+        print(f"dataset has {len(dataset)} entries after filtering "
+              f"out {num_filtered_out} entries with unsupported features")
         len_dataset = len(dataset)
         for data_point_idx in range(args.num_prompts):
             idx = data_point_idx
@@ -220,21 +233,21 @@ async def get_request(
     burstiness: float = 1.0,
 ) -> AsyncGenerator[tuple[int, SampleRequest], None]:
     """
     Asynchronously generates requests at a specified rate
     with OPTIONAL burstiness.

     Args:
         input_requests:
             A list of input requests, each represented as a tuple.
         request_rate:
             The rate at which requests are generated (requests/s).
         burstiness (optional):
             The burstiness factor of the request generation.
             Only takes effect when request_rate is not inf.
             Default value is 1, which follows a Poisson process.
             Otherwise, the request intervals follow a gamma distribution.
             A lower burstiness value (0 < burstiness < 1) results
             in more bursty requests, while a higher burstiness value
             (burstiness > 1) results in a more uniform arrival of requests.
     """
     input_requests = iter(input_requests)
@@ -378,8 +391,8 @@ async def benchmark(
     selected_percentiles: list[str],
     ignore_eos: bool,
     max_concurrency: Optional[int],
-    guided_decoding_ratio: float,
-    guided_decoding_backend: str,
+    structured_output_ratio: float,
+    structured_output_backend: str,
     goodput_config_dict: Optional[dict[str, float]] = None,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
@@ -391,16 +404,18 @@ async def benchmark(
         extra_body = {}
         # Add the schema to the extra_body
         extra_body[request.structure_type] = request.schema
-        # Add the specific guided_decoding_backend
-        extra_body["guided_decoding_backend"] = guided_decoding_backend
+        # Add the specific structured_output_backend
+        extra_body["guided_decoding_backend"] = structured_output_backend
         return extra_body

     print("Starting initial single prompt test run...")
-    guided_decoding_req_idx = random.sample(
+    structured_output_req_idx = random.sample(
         range(len(input_requests)),
-        int(len(input_requests) * guided_decoding_ratio))
+        int(len(input_requests) * structured_output_ratio))

     test_request = input_requests[0]
+    test_req_extra_body = (prepare_extra_body(test_request)
+                           if 0 in structured_output_req_idx else None)
     test_input = RequestFuncInput(
         model=model_id,
         prompt=test_request.prompt,
@@ -408,7 +423,7 @@ async def benchmark(
         prompt_len=test_request.prompt_len,
         output_len=test_request.expected_output_len,
         ignore_eos=ignore_eos,
-        extra_body=prepare_extra_body(test_request),
+        extra_body=test_req_extra_body,
     )
     test_output = await request_func(request_func_input=test_input)
     if not test_output.success:
@@ -427,7 +442,7 @@ async def benchmark(
             prompt_len=test_request.prompt_len,
             output_len=test_request.expected_output_len,
             ignore_eos=ignore_eos,
-            extra_body=prepare_extra_body(test_request),
+            extra_body=test_req_extra_body,
         )
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
@@ -465,7 +480,7 @@ async def benchmark(
     async for i, request in get_request(input_requests, request_rate,
                                         burstiness):
         extra_body = prepare_extra_body(
-            request) if i in guided_decoding_req_idx else None
+            request) if i in structured_output_req_idx else None
         request_func_input = RequestFuncInput(
             model=model_id,
             prompt=request.prompt,
@@ -708,10 +723,10 @@ def main(args: argparse.Namespace):
     else:
         args.structure_type = 'guided_json'

-    if args.no_guided_decoding:
-        args.guided_decoding_ratio = 0
+    if args.no_structured_output:
+        args.structured_output_ratio = 0
     if args.save_results:
-        result_file_name = f'{args.guided_decoding_ratio}guided'
+        result_file_name = f'{args.structured_output_ratio}guided'
         result_file_name += f"_{backend}"
         result_file_name += f"_{args.request_rate}qps"
         result_file_name += f"_{args.model.split('/')[-1]}"
@@ -744,8 +759,8 @@ def main(args: argparse.Namespace):
         ],
         ignore_eos=args.ignore_eos,
         max_concurrency=args.max_concurrency,
-        guided_decoding_ratio=args.guided_decoding_ratio,
-        guided_decoding_backend=args.guided_decoding_backend,
+        structured_output_ratio=args.structured_output_ratio,
+        structured_output_backend=args.structured_output_backend,
         goodput_config_dict=goodput_config_dict,
     ))

@@ -943,19 +958,19 @@ if __name__ == "__main__":
         "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
         "and the blog: https://hao-ai-lab.github.io/blogs/distserve")

-    parser.add_argument("--no-guided-decoding",
+    parser.add_argument("--no-structured-output",
                         action='store_true',
                         default=False,
                         help="Whether to disable JSON decoding or not.")
-    parser.add_argument("--guided-decoding-ratio",
+    parser.add_argument("--structured-output-ratio",
                         type=float,
                         default=1.0,
-                        help="Ratio of Guided Decoding requests")
-    parser.add_argument("--guided-decoding-backend",
+                        help="Ratio of Structured Outputs requests")
+    parser.add_argument("--structured-output-backend",
                         type=str,
                         choices=["outlines", "lm-format-enforcer", "xgrammar"],
                         default="xgrammar",
-                        help="Backend to use for guided decoding")
+                        help="Backend to use for structured outputs")

     args = parser.parse_args()
     main(args)
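For reference (not part of the diff): the extra request-body fields that prepare_extra_body() attaches are plain fields of the OpenAI-compatible completions request, so they can also be sent by hand. A minimal sketch, assuming a locally running vLLM server on port 8000; the model name, prompt, and schema are illustrative:

import json
import urllib.request

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
payload = {
    "model": "Qwen/Qwen2.5-7B-Instruct",   # illustrative model name
    "prompt": "Give a JSON object with a name field.",
    "max_tokens": 200,
    # The same extra fields the benchmark adds for structured requests:
    "guided_json": schema,
    "guided_decoding_backend": "xgrammar",
}
req = urllib.request.Request(
    "http://localhost:8000/v1/completions",  # assumed local server address
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
print(urllib.request.urlopen(req).read().decode())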
benchmarks/run_structured_output_benchmark.sh (new executable file): 64 lines
@@ -0,0 +1,64 @@
(new file; added content follows)

#!/bin/bash

# Define the model to use
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}

# Define the backend to use
BACKEND=${2:-"vllm"}

# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}

# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}

GUIDED_RATIO=${6:-0.5}

# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"

# Define QPS values to test
QPS_VALUES=(70 60 50 25 20 15 10)

# Common parameters
COMMON_PARAMS="--backend $BACKEND \
               --model $MODEL \
               --dataset $DATASET \
               --structured-output-backend $GUIDED_BACKEND \
               --structured-output-ratio $GUIDED_RATIO \
               --save-results \
               --result-dir $OUTPUT_DIR"

echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------"

# Run benchmarks with different QPS values
for qps in "${QPS_VALUES[@]}"; do
  echo "Running benchmark with QPS: $qps"

  # Get git hash and branch for the filename
  GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
  GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")

  # Construct filename for this run
  FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"

  # Run the benchmark
  python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
    --request-rate $qps \
    --result-filename "$FILENAME" \
    --port ${PORT:-8000}

  echo "Completed benchmark with QPS: $qps"
  echo "----------------------------------------"
done

echo "All benchmarks completed!"
echo "Results saved to: $OUTPUT_DIR"
@@ -1,113 +1,25 @@
Removed (old "User Profile" schema):

{
    "$schema":
        "https://json-schema.org/draft/2020-12/schema",
    "title":
        "User Profile",
    "type":
        "object",
    "properties": {
        "userId": {
            "type": "string",
            "description": "Unique identifier for the user."
        },
        "personalInfo": {
            "type": "object",
            "properties": {
                "firstName": {
                    "type": "string",
                    "description": "The user's first name."
                },
                "lastName": {
                    "type": "string",
                    "description": "The user's last name."
                },
                "age": {
                    "type": "integer",
                    "minimum": 0,
                    "description": "The user's age."
                },
                "phoneNumbers": {
                    "type":
                        "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "type": {
                                "type": "string",
                                "enum": ["home", "work", "mobile"],
                                "description": "Type of phone number."
                            },
                            "number": {
                                "type": "string",
                                "pattern": "^\\+?[1-9]\\d{1,14}$",
                                "description": "Phone number in E.164 format."
                            }
                        },
                        "required": ["type", "number"]
                    },
                    "description":
                        "List of phone numbers associated with the user."
                }
            },
            "required": ["firstName", "lastName"]
        },
        "address": {
            "type": "object",
            "properties": {
                "street": {
                    "type": "string",
                    "description": "Street address."
                },
                "city": {
                    "type": "string",
                    "description": "City name."
                },
                "state": {
                    "type": "string",
                    "description": "State or province."
                },
                "postalCode": {
                    "type": "string",
                    "pattern": "^\\d{5}(-\\d{4})?$",
                    "description": "Postal code."
                },
                "country": {
                    "type": "string",
                    "description": "Country name."
                }
            },
            "required": ["street", "city", "state", "postalCode", "country"]
        },
        "preferences": {
            "type": "object",
            "properties": {
                "newsletterSubscribed": {
                    "type":
                        "boolean",
                    "description":
                        "Indicates if the user is subscribed to the newsletter."
                },
                "favoriteCategories": {
                    "type": "array",
                    "items": {
                        "type": "string"
                    },
                    "description": "List of user's favorite categories."
                }
            },
            "required": ["newsletterSubscribed"]
        },
        "accountStatus": {
            "type": "string",
            "enum": ["active", "inactive", "suspended"],
            "description": "Current status of the user's account."
        },
        "registrationDate": {
            "type": "string",
            "format": "date-time",
            "description": "ISO 8601 formatted date-time of user registration."
        }
    },
    "required":
        ["userId", "personalInfo", "address", "accountStatus", "registrationDate"]
}

Added (new schema):

{
  "type": "array",
  "items": {
    "type": "object",
    "properties": {
      "name": { "type": "string" },
      "race": { "type": "string" },
      "class": { "type": "string" },
      "level": { "type": "integer" },
      "background": { "type": "string" },
      "alignment": { "type": "string" },
      "backstory": { "type": "string" }
    },
    "required": [
      "name",
      "race",
      "class",
      "level",
      "background",
      "alignment",
      "backstory"
    ]
  }
}
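Not part of the diff: a quick sanity check that an instance conforms to the new schema, using the same jsonschema validator the tests added in this commit rely on; the example values are made up:

import jsonschema  # same validator used by the new V1 tests

characters_schema = {
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "name": {"type": "string"},
            "race": {"type": "string"},
            "class": {"type": "string"},
            "level": {"type": "integer"},
            "background": {"type": "string"},
            "alignment": {"type": "string"},
            "backstory": {"type": "string"},
        },
        "required": [
            "name", "race", "class", "level",
            "background", "alignment", "backstory",
        ],
    },
}

example = [{
    "name": "Aria",            # illustrative values only
    "race": "Elf",
    "class": "Ranger",
    "level": 3,
    "background": "Outlander",
    "alignment": "Chaotic Good",
    "backstory": "Raised in the northern woods.",
}]

# Raises jsonschema.ValidationError if the instance does not match.
jsonschema.validate(instance=example, schema=characters_schema)
print("example conforms to the schema")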
@@ -1,12 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 from typing import Optional

-from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
+from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager

 EOS_TOKEN_ID = 50256

@@ -36,13 +37,21 @@ def create_scheduler(
         swap_space=0,
         cache_dtype="auto",
     )
+    vllm_config = VllmConfig(
+        scheduler_config=scheduler_config,
+        model_config=model_config,
+        cache_config=cache_config,
+    )
     cache_config.num_gpu_blocks = 10000
-    return Scheduler(scheduler_config,
-                     model_config,
-                     cache_config,
-                     speculative_config=None,
-                     lora_config=None,
-                     log_stats=True)
+    return Scheduler(
+        scheduler_config,
+        model_config,
+        cache_config,
+        speculative_config=None,
+        lora_config=None,
+        log_stats=True,
+        structured_output_manager=StructuredOutputManager(vllm_config),
+    )


 def create_requests(
@@ -249,7 +258,9 @@ def test_stop_via_update_from_output():
         },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
-        free_encoder_input_ids=[])
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None)

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -299,7 +310,9 @@ def test_stop_via_update_from_output():
         },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
-        free_encoder_input_ids=[])
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None)

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -347,7 +360,9 @@ def test_stop_via_update_from_output():
         },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
-        free_encoder_input_ids=[])
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None)

     model_output = ModelRunnerOutput(
         req_ids=[req.request_id for req in requests],
@@ -392,7 +407,9 @@ def test_stop_via_update_from_output():
         },
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
-        free_encoder_input_ids=[])
+        free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None)

     model_output = ModelRunnerOutput(
         req_ids=[requests[0].request_id],
@@ -29,6 +29,7 @@ def sample_regex():
             r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


+# Note: Ensure this only uses attributes compatible with xgrammar
 @pytest.fixture
 def sample_json_schema():
     return {
@@ -44,9 +45,7 @@ def sample_json_schema():
             "type": "array",
             "items": {
                 "type": "string",
-                "maxLength": 10
-            },
-            "minItems": 3
+            }
         },
         "work_history": {
             "type": "array",
@@ -71,8 +70,9 @@ def sample_json_schema():
     }


+# A schema unsupported by xgrammar
 @pytest.fixture
-def sample_complex_json_schema():
+def unsupported_json_schema():
     return {
         "type": "object",
         "properties": {
@@ -150,7 +150,19 @@ def sample_guided_choice():


 @pytest.fixture
-def sample_sql_statements():
+def sample_sql_ebnf():
+    return """
+root ::= select_statement
+select_statement ::= "SELECT" column "from" table "where" condition
+column ::= "col_1" | "col_2"
+table ::= "table_1" | "table_2"
+condition ::= column "=" number
+number ::= "1" | "2"
+"""
+
+
+@pytest.fixture
+def sample_sql_lark():
     return ("""
 start: select_statement
 select_statement: "SELECT" column "from" table "where" condition
tests/v1/entrypoints/llm/__init__.py (new file): 0 lines
tests/v1/entrypoints/llm/test_struct_output_generate.py (new file): 269 lines
@ -0,0 +1,269 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import jsonschema
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from vllm.entrypoints.llm import LLM
|
||||||
|
from vllm.outputs import RequestOutput
|
||||||
|
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||||
|
|
||||||
|
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||||
|
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
|
def test_guided_json_completion(monkeypatch, sample_json_schema,
|
||||||
|
guided_decoding_backend: str):
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||||
|
sampling_params = SamplingParams(temperature=1.0,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(
|
||||||
|
json=sample_json_schema,
|
||||||
|
backend=guided_decoding_backend))
|
||||||
|
outputs = llm.generate(prompts=[
|
||||||
|
f"Give an example JSON for an employee profile "
|
||||||
|
f"that fits this schema: {sample_json_schema}"
|
||||||
|
] * 2,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
output_json = json.loads(generated_text)
|
||||||
|
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
|
def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||||
|
sampling_params = SamplingParams(temperature=1.0,
|
||||||
|
max_tokens=100,
|
||||||
|
n=2,
|
||||||
|
guided_decoding=GuidedDecodingParams(
|
||||||
|
json_object=True,
|
||||||
|
backend=guided_decoding_backend))
|
||||||
|
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts=("Generate a JSON object with curly braces for a person with "
|
||||||
|
"name and age fields for John Smith who is 31 years old."),
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
|
||||||
|
for i in range(2):
|
||||||
|
generated_text = output.outputs[i].text
|
||||||
|
print(generated_text)
|
||||||
|
assert generated_text is not None
|
||||||
|
|
||||||
|
# Parse to verify it is valid JSON
|
||||||
|
parsed_json = json.loads(generated_text)
|
||||||
|
assert isinstance(parsed_json, dict)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
|
def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
|
||||||
|
guided_decoding_backend: str):
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||||
|
sampling_params = SamplingParams(temperature=1.0,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(
|
||||||
|
json=unsupported_json_schema,
|
||||||
|
backend=guided_decoding_backend))
|
||||||
|
with pytest.raises(ValueError,
|
||||||
|
match="The provided JSON schema contains features "
|
||||||
|
"not supported by xgrammar."):
|
||||||
|
llm.generate(prompts=[
|
||||||
|
f"Give an example JSON for an employee profile "
|
||||||
|
f"that fits this schema: {unsupported_json_schema}"
|
||||||
|
] * 2,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
|
def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
|
||||||
|
guided_decoding_backend: str):
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||||
|
sampling_params = SamplingParams(temperature=0.8,
|
||||||
|
top_p=0.95,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(
|
||||||
|
grammar=sample_sql_ebnf,
|
||||||
|
backend=guided_decoding_backend))
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts=("Generate a sql statement that selects col_1 from "
|
||||||
|
"table_1 where it is equal to 1"),
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
|
||||||
|
# remove spaces for comparison b/c we removed them in the grammar
|
||||||
|
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||||
|
" ", "")
|
||||||
|
|
||||||
|
assert generated_text.strip() == ground_truth
|
||||||
|
|
||||||
|
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip_global_cleanup
|
||||||
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
|
GUIDED_DECODING_BACKENDS_V1)
|
||||||
|
def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
|
||||||
|
guided_decoding_backend: str):
|
||||||
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
|
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||||
|
sampling_params = SamplingParams(temperature=0.8,
|
||||||
|
top_p=0.95,
|
||||||
|
max_tokens=1000,
|
||||||
|
guided_decoding=GuidedDecodingParams(
|
||||||
|
grammar=sample_sql_lark,
|
||||||
|
backend=guided_decoding_backend))
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts=("Generate a sql statement that selects col_1 from "
|
||||||
|
"table_1 where it is equal to 1"),
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert outputs is not None
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
prompt = output.prompt
|
||||||
|
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
|
||||||
|
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||||
|
from lark import Lark
|
||||||
|
parser = Lark(sample_sql_lark)
|
||||||
|
parser.parse(generated_text)
|
||||||
|
|
||||||
|
# remove spaces for comparison b/c we removed them in the grammar
|
||||||
|
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||||
|
" ", "")
|
||||||
|
|
||||||
|
        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf_invalid(monkeypatch,
                                     guided_decoding_backend: str):
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=MODEL_NAME, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         grammar="not a grammar",
                                         backend=guided_decoding_backend))
    with pytest.raises(ValueError,
                       match="Failed to convert the grammar "
                       "from Lark to EBNF."):
        llm.generate(
            prompts=("Generate a sql statement that selects col_1 from "
                     "table_1 where it is equal to 1"),
            sampling_params=sampling_params,
            use_tqdm=True,
        )


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=MODEL_NAME, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_decoding=GuidedDecodingParams(
                                         regex=sample_regex,
                                         backend=guided_decoding_backend))
    with pytest.raises(ValueError,
                       match="Regex guided decoding is not supported."):
        llm.generate(prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
                     sampling_params=sampling_params,
                     use_tqdm=True)

    # Once regex is supported --
    #assert outputs is not None
    #for output in outputs:
    #    assert output is not None
    #    assert isinstance(output, RequestOutput)
    #    prompt = output.prompt
    #    generated_text = output.outputs[0].text
    #    print(generated_text)
    #    assert generated_text is not None
    #    assert re.fullmatch(sample_regex, generated_text) is not None
    #    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_choice_completion(monkeypatch, sample_guided_choice,
                                  guided_decoding_backend: str):
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=MODEL_NAME, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_decoding=GuidedDecodingParams(
                                         choice=sample_guided_choice,
                                         backend=guided_decoding_backend))
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
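For reference, a minimal end-to-end sketch of how a caller would request JSON-constrained output through the V1 path these tests exercise. The model name and schema below are placeholders (the tests use their own MODEL_NAME fixture), and V1 is enabled the same way the tests do it via the environment variable:

import json
import os

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

os.environ["VLLM_USE_V1"] = "1"  # same switch the tests flip via monkeypatch

schema = {  # placeholder schema
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)  # placeholder model
params = SamplingParams(
    temperature=0.0,
    max_tokens=200,
    guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"),
)
outputs = llm.generate("Give me a JSON object describing a person.", params)
print(json.loads(outputs[0].outputs[0].text))  # parses because decoding was constrained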
0     tests/v1/structured_output/__init__.py    Normal file
196   tests/v1/structured_output/test_utils.py  Normal file

@@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0

import pytest

from vllm.v1.structured_output.utils import (
    has_xgrammar_unsupported_json_features)


@pytest.fixture
def unsupported_string_schemas():
    return [
        {
            "type": "string",
            "pattern": "^[a-zA-Z]+$"
        },
        {
            "type": "string",
            "enum": ["active", "inactive", "pending"]
        },
        {
            "type": "string",
            "minLength": 1
        },
        {
            "type": "string",
            "maxLength": 100
        },
        {
            "type": "string",
            "format": "email"
        },
    ]


@pytest.fixture
def unsupported_integer_schemas():
    return [
        {
            "type": "integer",
            "minimum": 0
        },
        {
            "type": "integer",
            "maximum": 120
        },
        {
            "type": "integer",
            "exclusiveMinimum": 120
        },
        {
            "type": "integer",
            "exclusiveMaximum": 120
        },
        {
            "type": "integer",
            "multipleOf": 120
        },
    ]


@pytest.fixture
def unsupported_number_schemas():
    return [
        {
            "type": "number",
            "minimum": 0
        },
        {
            "type": "number",
            "maximum": 120
        },
        {
            "type": "number",
            "exclusiveMinimum": 120
        },
        {
            "type": "number",
            "exclusiveMaximum": 120
        },
        {
            "type": "number",
            "multipleOf": 120
        },
    ]


@pytest.fixture
def unsupported_array_schemas():
    return [
        {
            "type": "array",
            "uniqueItems": True
        },
        {
            "type": "array",
            "contains": {
                "type": "string"
            }
        },
        {
            "type": "array",
            "minContains": 1
        },
        {
            "type": "array",
            "maxContains": 5
        },
        {
            "type": "array",
            "minItems": 1
        },
        {
            "type": "array",
            "maxItems": 10
        },
    ]


@pytest.fixture
def unsupported_object_schemas():
    return [
        {
            "type": "object",
            "minProperties": 1
        },
        {
            "type": "object",
            "maxProperties": 5
        },
        {
            "type": "object",
            "propertyNames": {
                "pattern": "^[a-z]+$"
            }
        },
        {
            "type": "object",
            "patternProperties": {
                "^S": {
                    "type": "string"
                }
            }
        },
    ]


@pytest.fixture
def supported_schema():
    return {
        "type": "object",
        "properties": {
            "name": {
                "type": "string"
            },
            "age": {
                "type": "integer"
            },
            "status": {
                "type": "string"
            },
            "scores": {
                "type": "array",
                "items": {
                    "type": "number"
                }
            },
            "address": {
                "type": "object",
                "properties": {
                    "street": {
                        "type": "string"
                    },
                    "city": {
                        "type": "string"
                    }
                }
            }
        }
    }


@pytest.mark.parametrize("schema_type", [
    "unsupported_string_schemas", "unsupported_integer_schemas",
    "unsupported_number_schemas", "unsupported_array_schemas",
    "unsupported_object_schemas"
])
def test_unsupported_json_features_by_type(schema_type, request):
    schemas = request.getfixturevalue(schema_type)
    for schema in schemas:
        assert has_xgrammar_unsupported_json_features(
            schema), f"Schema should be unsupported: {schema}"


def test_supported_json_features(supported_schema):
    assert not has_xgrammar_unsupported_json_features(
        supported_schema), "Schema should be supported"
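A quick illustration of what the helper under test flags, assuming it is importable exactly as in the test file above; the nested schema is made up for the example:

from vllm.v1.structured_output.utils import (
    has_xgrammar_unsupported_json_features)

# A constraint hidden two levels deep is still detected, because the check
# recurses through nested dicts and lists.
nested = {
    "type": "object",
    "properties": {
        "emails": {
            "type": "array",
            "items": {"type": "string", "format": "email"},  # "format" is unsupported
        }
    },
}
assert has_xgrammar_unsupported_json_features(nested)

# Dropping the "format" keyword makes the same shape acceptable.
nested["properties"]["emails"]["items"].pop("format")
assert not has_xgrammar_unsupported_json_features(nested)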
@@ -72,6 +72,8 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )


@@ -135,6 +137,8 @@ def test_update_states_request_finished(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids={req_id},
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     metadata_before = model_runner.input_batch.sampling_metadata
@@ -165,6 +169,8 @@ def test_update_states_request_resumed(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     model_runner._update_states(scheduler_output)
@@ -190,6 +196,8 @@ def test_update_states_request_resumed(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     metadata_before = model_runner.input_batch.sampling_metadata
@@ -221,6 +229,8 @@ def test_update_states_no_changes(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     metadata_before = model_runner.input_batch.sampling_metadata
@@ -256,6 +266,8 @@ def test_update_states_request_unscheduled(model_runner):
         num_common_prefix_blocks=0,
         finished_req_ids=set(),
         free_encoder_input_ids=[],
+        structured_output_request_ids={},
+        grammar_bitmask=None,
     )

     metadata_before = model_runner._update_states(scheduler_output)
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
 import argparse
 import asyncio
 import concurrent
@@ -8,6 +10,7 @@ import datetime
 import enum
 import gc
 import getpass
+import importlib
 import importlib.metadata
 import importlib.util
 import inspect
@@ -23,6 +26,7 @@ import tempfile
 import threading
 import time
 import traceback
+import types
 import uuid
 import warnings
 import weakref
@@ -982,7 +986,7 @@ def current_stream() -> torch.cuda.Stream:
     return _current_stream


-def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
+def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
     """Set up function tracing for the current thread,
     if enabled via the VLLM_TRACE_FUNCTION environment variable
     """
@@ -1977,7 +1981,7 @@ class MemorySnapshot:
         self.non_torch_memory = self.cuda_memory - self.torch_memory
         self.timestamp = time.time()

-    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
+    def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
         return MemorySnapshot(
             torch_peak=self.torch_peak - other.torch_peak,
             cuda_memory=self.cuda_memory - other.cuda_memory,
@@ -2306,3 +2310,54 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:

     type.__setattr__(cls, '__init__', wrapped_init)
     return cls
+
+
+class LazyLoader(types.ModuleType):
+    """
+    LazyLoader module borrowed from Tensorflow
+    https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
+    with an addition of "module caching".
+
+    Lazily import a module, mainly to avoid pulling in large dependencies.
+    Modules such as `xgrammar` might do additional side effects, so we
+    only want to use this when it is needed, delaying all eager effects.
+    """
+
+    def __init__(
+        self,
+        local_name: str,
+        parent_module_globals: dict[str, Any],
+        name: str,
+    ):
+        self._local_name = local_name
+        self._parent_module_globals = parent_module_globals
+        self._module: types.ModuleType | None = None
+
+        super().__init__(str(name))
+
+    def _load(self) -> types.ModuleType:
+        # Import the target module and insert it into the parent's namespace
+        try:
+            module = importlib.import_module(self.__name__)
+            self._parent_module_globals[self._local_name] = module
+            # The additional add to sys.modules
+            # ensures library is actually loaded.
+            sys.modules[self._local_name] = module
+        except ModuleNotFoundError as err:
+            raise err from None
+
+        # Update this object's dict so that if someone keeps a
+        # reference to the LazyLoader, lookups are efficient
+        # (__getattr__ is only called on lookups that fail).
+        self.__dict__.update(module.__dict__)
+        return module
+
+    def __getattr__(self, item: Any) -> Any:
+        if self._module is None:
+            self._module = self._load()
+        return getattr(self._module, item)
+
+    def __dir__(self) -> list[str]:
+        if self._module is None:
+            self._module = self._load()
+        return dir(self._module)
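A small usage sketch of the LazyLoader added above, mirroring how the structured output modules in this change wrap `xgrammar`; numpy stands in here for a heavy optional dependency:

import types

from vllm.utils import LazyLoader

# Nothing is imported until an attribute is touched; "np" starts out as a
# proxy ModuleType that only remembers the target module name.
np = LazyLoader("np", globals(), "numpy")
assert isinstance(np, types.ModuleType)

# The first attribute access performs the real import, caches the module in
# the parent globals and sys.modules, and then delegates every lookup.
print(np.arange(4).sum())  # -> 6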
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
 import time
 from collections import deque
 from collections.abc import Iterable
@@ -18,6 +20,7 @@ from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
 from vllm.v1.metrics.stats import SchedulerStats
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager

 logger = init_logger(__name__)

@@ -32,12 +35,14 @@ class Scheduler:
         lora_config: Optional[LoRAConfig],
         speculative_config: Optional[SpeculativeConfig],
         log_stats: bool,
+        structured_output_manager: StructuredOutputManager,
     ) -> None:
         self.scheduler_config = scheduler_config
         self.cache_config = cache_config
         self.lora_config = lora_config
         self.speculative_config = speculative_config
         self.log_stats = log_stats
+        self.structured_output_manager = structured_output_manager

         # Scheduling constraints.
         self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -97,7 +102,7 @@ class Scheduler:
         self.encoder_cache_manager = EncoderCacheManager(
             cache_size=encoder_cache_size)

-    def schedule(self) -> "SchedulerOutput":
+    def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
         # There's no "decoding phase" nor "prefill phase" in the scheduler.
         # Each request just has the num_computed_tokens and
@@ -114,6 +119,14 @@ class Scheduler:
         scheduled_running_reqs: list[Request] = []
         preempted_reqs: list[Request] = []

+        # NOTE: structured_output_request_ids maps the request_id of a
+        # request that uses structured output to its index in the running
+        # batch. This helps us determine how to slice the grammar bitmask
+        # and apply a valid mask only to requests that use structured
+        # decoding.
+        structured_output_request_ids: dict[str, int] = {}
+
         req_to_new_block_ids: dict[str, list[int]] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
@@ -184,6 +197,12 @@ class Scheduler:
             # Schedule the request.
             scheduled_running_reqs.append(request)
             self.scheduled_req_ids.add(request.request_id)
+            if request.use_structured_output:
+                # PERF: in case of chunked prefill, the request might not
+                # include any new tokens. Therefore, we might spend some
+                # additional cycles filling in the bitmask, which could be
+                # a big no-op.
+                structured_output_request_ids[request.request_id] = req_index
             req_to_new_block_ids[request.request_id] = [
                 b.block_id for b in new_blocks
             ]
@@ -219,6 +238,10 @@ class Scheduler:
                 if req.lora_request and req.lora_request.lora_int_id > 0)
             assert len(requested_loras) <= self.lora_config.max_loras

+        # Use a temporary deque to collect requests that need to be skipped
+        # and put back at the head of the waiting queue later
+        waiting_for_fsm: deque[Request] = deque()
+
         # Next, schedule the WAITING requests.
         if not preempted_reqs:
             while self.waiting and token_budget > 0:
@@ -227,6 +250,16 @@ class Scheduler:

                 request = self.waiting[0]

+                if request.status == RequestStatus.WAITING_FOR_FSM:
+                    structured_output_req = request.structured_output_request
+                    if structured_output_req and structured_output_req.grammar:
+                        request.status = RequestStatus.WAITING
+                    else:
+                        waiting_structured_output_req = self.waiting.popleft()
+                        waiting_for_fsm.appendleft(
+                            waiting_structured_output_req)
+                        continue
+
                 # Check that adding the request still respects the max_loras
                 # constraint.
                 if self.lora_config and request.lora_request:
@@ -281,6 +314,10 @@ class Scheduler:
                     break

                 self.waiting.popleft()
+                if request.use_structured_output:
+                    structured_output_request_ids[
+                        request.request_id] = req_index
+                    req_index += 1
                 self.running.append(request)
                 self.scheduled_req_ids.add(request.request_id)
                 self.request_scheduled(request, scheduled_timestamp)
@@ -311,6 +348,10 @@ class Scheduler:
                     self.encoder_cache_manager.allocate(request, i)
                     encoder_budget = new_encoder_budget

+        # Put back any skipped requests at the head of the waiting queue
+        if waiting_for_fsm:
+            self.waiting.extendleft(waiting_for_fsm)
+
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
@@ -331,6 +372,11 @@ class Scheduler:
             self.kv_cache_manager.get_num_common_prefix_blocks(
                 any_request, len(self.running)))

+        grammar_bitmask = self.structured_output_manager.grammar_bitmask(
+            self.requests,
+            structured_output_request_ids,
+            len(self.running),
+        )
         # Construct the scheduler output.
         new_reqs_data = [
             NewRequestData.from_request(req,
@@ -369,6 +415,8 @@ class Scheduler:
             # the previous and the current steps.
             finished_req_ids=self.finished_req_ids,
             free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
+            structured_output_request_ids=structured_output_request_ids,
+            grammar_bitmask=grammar_bitmask,
         )

         self.finished_req_ids = set()
@@ -381,7 +429,7 @@ class Scheduler:
         num_scheduled_spec_tokens: int,
         new_block_ids: list[int],
         resumed_from_preemption: bool,
-    ) -> "CachedRequestData":
+    ) -> CachedRequestData:
         # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
         # them at each scheduling step.
         num_computed_tokens = request.num_computed_tokens
@@ -474,8 +522,8 @@ class Scheduler:

     def update_from_output(
         self,
-        scheduler_output: "SchedulerOutput",
-        model_runner_output: "ModelRunnerOutput",
+        scheduler_output: SchedulerOutput,
+        model_runner_output: ModelRunnerOutput,
     ) -> EngineCoreOutputs:
         sampled_token_ids = model_runner_output.sampled_token_ids
         spec_token_ids = model_runner_output.spec_token_ids
@@ -565,6 +613,15 @@ class Scheduler:
                 # the outer lists can be of length > 1.
                 new_logprobs = logprobs.slice(req_index, req_index + 1)

+            if new_token_ids and request.use_structured_output:
+                # NOTE: structured_output_request should not be None when
+                # use_structured_output is set; we have checked above, so it
+                # is safe to ignore the type warning.
+                request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
+                    request.request_id,
+                    new_token_ids,
+                )
+
             # Transmit partial if chunked prefill & prompt logprobs is enabled
             if new_token_ids or prompt_logprobs_tensors is not None:
                 # Add EngineCoreOutput for this Request.
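The skip-and-requeue dance around WAITING_FOR_FSM can be hard to read inline, so here is the same deque pattern in isolation, with plain strings standing in for requests (purely illustrative, not engine code):

from collections import deque

waiting = deque(["fsm_pending", "ready_a", "ready_b"])
waiting_for_fsm: deque = deque()
scheduled = []

while waiting:
    req = waiting[0]
    if req == "fsm_pending":                    # grammar still compiling
        waiting_for_fsm.appendleft(waiting.popleft())
        continue                                # move on to the next request
    scheduled.append(waiting.popleft())

# Skipped requests go back to the *head* of the queue so they keep priority.
if waiting_for_fsm:
    waiting.extendleft(waiting_for_fsm)

print(scheduled)      # ['ready_a', 'ready_b']
print(list(waiting))  # ['fsm_pending']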
@@ -1,9 +1,14 @@
 # SPDX-License-Identifier: Apache-2.0

+from __future__ import annotations
+
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Optional

 if TYPE_CHECKING:
+    import numpy as np
+    import numpy.typing as npt
+
     from vllm.lora.request import LoRARequest
     from vllm.multimodal import MultiModalKwargs
     from vllm.multimodal.base import PlaceholderRange
@@ -17,20 +22,20 @@ class NewRequestData:
     req_id: str
     prompt_token_ids: list[int]
     prompt: Optional[str]
-    mm_inputs: list["MultiModalKwargs"]
+    mm_inputs: list[MultiModalKwargs]
     mm_hashes: list[str]
-    mm_positions: list["PlaceholderRange"]
+    mm_positions: list[PlaceholderRange]
-    sampling_params: "SamplingParams"
+    sampling_params: SamplingParams
     block_ids: list[int]
     num_computed_tokens: int
-    lora_request: Optional["LoRARequest"]
+    lora_request: Optional[LoRARequest]

     @classmethod
     def from_request(
         cls,
-        request: "Request",
+        request: Request,
         block_ids: list[int],
-    ) -> "NewRequestData":
+    ) -> NewRequestData:
         return cls(
             req_id=request.request_id,
             prompt_token_ids=request.prompt_token_ids,
@@ -60,11 +65,11 @@ class CachedRequestData:
     @classmethod
     def from_request(
         cls,
-        request: "Request",
+        request: Request,
         resumed_from_preemption: bool,
         new_token_ids: list[int],
         new_block_ids: list[int],
-    ) -> "CachedRequestData":
+    ) -> CachedRequestData:
         return cls(
             req_id=request.request_id,
             resumed_from_preemption=resumed_from_preemption,
@@ -111,3 +116,9 @@ class SchedulerOutput:
     # list of (req_id, encoder_input_index) tuples.
     # Used to free the encoder cache.
     free_encoder_input_ids: list[tuple[str, int]]
+
+    # Dict of request ids to their index within the batch
+    # for filling the next token bitmask
+    structured_output_request_ids: dict[str, int]
+    # the bitmask for the whole batch
+    grammar_bitmask: Optional[npt.NDArray[np.int32]]
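The grammar_bitmask travels to the GPU workers as a plain numpy array. A rough sketch of its expected shape, assuming xgrammar's packed layout of 32 vocabulary entries per int32 and its convention that a set bit marks an allowed token; the vocab and batch sizes are made up:

import numpy as np

vocab_size = 152_064               # hypothetical vocabulary size
num_running_reqs = 8               # hypothetical batch size
words_per_row = -(-vocab_size // 32)  # ceil division: int32 words per request

# All bits set (-1) would mean "every token allowed"; the matcher clears bits
# for tokens the grammar forbids before the mask is shipped to the workers.
bitmask = np.full((num_running_reqs, words_per_row), -1, dtype=np.int32)
print(bitmask.shape)  # (8, 4752)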
@@ -72,9 +72,7 @@ class AsyncLLM(EngineClient):

         # Processor (converts Inputs --> EngineCoreRequests).
         self.processor = Processor(
-            model_config=vllm_config.model_config,
-            cache_config=vllm_config.cache_config,
-            lora_config=vllm_config.lora_config,
+            vllm_config=vllm_config,
             tokenizer=self.tokenizer,
             input_registry=input_registry,
         )
@@ -194,8 +192,8 @@ class AsyncLLM(EngineClient):
         * 3) Adding the Request to the Detokenizer.
         * 4) Adding the Request to the EngineCore (separate process).

         A separate output_handler loop runs in a background AsyncIO task,
         pulling outputs from EngineCore and putting them into the
         per-request AsyncStream.

         The caller of generate() iterates the returned AsyncGenerator,
@@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
+from vllm.v1.structured_output import StructuredOutputManager
 from vllm.version import __version__ as VLLM_VERSION

 logger = init_logger(__name__)
@@ -61,6 +62,8 @@ class EngineCore:
         vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
         vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

+        self.structured_output_manager = StructuredOutputManager(vllm_config)
+
         # Setup scheduler.
         self.scheduler = Scheduler(
             scheduler_config=vllm_config.scheduler_config,
@@ -69,6 +72,7 @@ class EngineCore:
             lora_config=vllm_config.lora_config,
             speculative_config=vllm_config.speculative_config,
             log_stats=self.log_stats,
+            structured_output_manager=self.structured_output_manager,
         )

         # Setup MM Input Mapper.
@@ -131,6 +135,9 @@ class EngineCore:
                 request.mm_inputs, request.mm_hashes)

         req = Request.from_engine_core_request(request)
+        if req.use_structured_output:
+            # Start grammar compilation asynchronously
+            self.structured_output_manager.populate_cache(req)

         self.scheduler.add_request(req)

@@ -148,11 +155,24 @@ class EngineCore:

         if not self.scheduler.has_unfinished_requests():
             return EngineCoreOutputs(
-                outputs=[], scheduler_stats=self.scheduler.make_stats())
+                outputs=[],
+                scheduler_stats=self.scheduler.make_stats(),
+            )
         scheduler_output = self.scheduler.schedule()
+
+        # This case may occur when the only unfinished requests are
+        # structured output requests where the grammar has not finished
+        # compiling yet, so there's nothing to run.
+        if scheduler_output.total_num_scheduled_tokens == 0:
+            return EngineCoreOutputs(
+                outputs=[],
+                scheduler_stats=self.scheduler.make_stats(),
+            )
+
         output = self.model_executor.execute_model(scheduler_output)
         engine_core_outputs = self.scheduler.update_from_output(
             scheduler_output, output)  # type: ignore

         return engine_core_outputs

     def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
@@ -66,9 +66,7 @@ class LLMEngine:
         self.tokenizer.ping()

         # Processor (convert Inputs --> EngineCoreRequests)
-        self.processor = Processor(model_config=vllm_config.model_config,
-                                   cache_config=vllm_config.cache_config,
-                                   lora_config=vllm_config.lora_config,
+        self.processor = Processor(vllm_config=vllm_config,
                                    tokenizer=self.tokenizer,
                                    input_registry=input_registry,
                                    mm_registry=mm_registry)
@@ -4,7 +4,7 @@ import time
 from collections.abc import Mapping
 from typing import Optional, Union

-from vllm.config import CacheConfig, LoRAConfig, ModelConfig
+from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
 from vllm.inputs.parse import is_encoder_decoder_inputs
@@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MMInputCacheClient
+from vllm.v1.structured_output.utils import validate_structured_output_request


 class Processor:

     def __init__(
         self,
-        model_config: ModelConfig,
-        cache_config: CacheConfig,
-        lora_config: Optional[LoRAConfig],
+        vllm_config: VllmConfig,
         tokenizer: BaseTokenizerGroup,
         input_registry: InputRegistry = INPUT_REGISTRY,
         mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
     ):

-        self.model_config = model_config
-        self.cache_config = cache_config
-        self.lora_config = lora_config
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
+        self.decoding_config = vllm_config.decoding_config
         self.tokenizer = tokenizer

-        self.generation_config_fields = model_config.try_get_generation_config(
-        )
-        self.input_preprocessor = InputPreprocessor(model_config,
+        self.generation_config_fields = (
+            self.model_config.try_get_generation_config())
+        self.input_preprocessor = InputPreprocessor(self.model_config,
                                                     self.tokenizer,
                                                     mm_registry)
         self.input_processor = input_registry.create_input_processor(
-            model_config)
+            self.model_config)

         # Multi-modal (huggingface) input mapper
-        self.mm_input_cache_client = MMInputCacheClient(model_config)
+        self.mm_input_cache_client = MMInputCacheClient(self.model_config)

         # Multi-modal hasher (for images)
-        self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
-            cache_config.enable_prefix_caching
+        self.use_hash = (
+            not self.model_config.disable_mm_preprocessor_cache) or \
+            self.cache_config.enable_prefix_caching

     def _validate_logprobs(
         self,
@@ -80,6 +82,8 @@ class Processor:
         self,
         params: SamplingParams,
     ) -> None:
+        self._validate_structured_output(params)
+
         if params.allowed_token_ids is None:
             return
         if not params.allowed_token_ids:
@@ -125,6 +129,21 @@ class Processor:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")

+    def _validate_structured_output(self, params: SamplingParams) -> None:
+        if not params.guided_decoding or not self.decoding_config:
+            return
+        if self.decoding_config.guided_decoding_backend != "xgrammar":
+            raise ValueError(
+                "Only xgrammar structured output is supported in V1.")
+        if (params.guided_decoding.backend
+                and params.guided_decoding.backend != 'xgrammar'):
+            raise ValueError(
+                "Only xgrammar structured output is supported in V1.")
+        if self.vllm_config.speculative_config:
+            raise ValueError("Structured output is not supported with "
+                             "speculative decoding.")
+        validate_structured_output_request(params)
+
     def process_inputs(
         self,
         request_id: str,
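The net effect of _validate_structured_output at the API surface is that unsupported constraint types fail fast instead of reaching the scheduler, as the V1 tests above also exercise. A hedged sketch (placeholder model, VLLM_USE_V1=1 assumed to be set as in the tests):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")  # placeholder model

# Regex constraints (and non-xgrammar backends) are rejected up front in V1.
bad = SamplingParams(guided_decoding=GuidedDecodingParams(regex=r"\d+"))
try:
    llm.generate("An example number: ", bad)
except ValueError as err:
    print(err)  # e.g. "Regex guided decoding is not supported."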
|
@ -3,13 +3,15 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
from vllm.lora.request import LoRARequest
|
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||||
EngineCoreRequest, FinishReason)
|
EngineCoreRequest, FinishReason)
|
||||||
|
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||||
from vllm.v1.utils import ConstantList
|
from vllm.v1.utils import ConstantList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
from vllm.multimodal.inputs import PlaceholderRange
|
||||||
|
|
||||||
@ -27,15 +29,19 @@ class Request:
|
|||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
eos_token_id: Optional[int],
|
eos_token_id: Optional[int],
|
||||||
arrival_time: float,
|
arrival_time: float,
|
||||||
lora_request: Optional[LoRARequest] = None,
|
lora_request: Optional["LoRARequest"] = None,
|
||||||
|
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
self.sampling_params = sampling_params
|
self.sampling_params = sampling_params
|
||||||
# Because of LoRA, the eos token id can be different for each request.
|
# Because of LoRA, the eos token id can be different for each request.
|
||||||
self.eos_token_id = eos_token_id
|
self.eos_token_id = eos_token_id
|
||||||
self.lora_request = lora_request
|
self.lora_request = lora_request
|
||||||
|
self.structured_output_request = structured_output_request
|
||||||
|
|
||||||
self.status = RequestStatus.WAITING
|
self.status = (RequestStatus.WAITING_FOR_FSM
|
||||||
|
if sampling_params.guided_decoding is not None else
|
||||||
|
RequestStatus.WAITING)
|
||||||
self.events: list[EngineCoreEvent] = []
|
self.events: list[EngineCoreEvent] = []
|
||||||
self.stop_reason: Union[int, str, None] = None
|
self.stop_reason: Union[int, str, None] = None
|
||||||
assert sampling_params.max_tokens is not None
|
assert sampling_params.max_tokens is not None
|
||||||
@ -78,6 +84,8 @@ class Request:
|
|||||||
eos_token_id=request.eos_token_id,
|
eos_token_id=request.eos_token_id,
|
||||||
arrival_time=request.arrival_time,
|
arrival_time=request.arrival_time,
|
||||||
lora_request=request.lora_request,
|
lora_request=request.lora_request,
|
||||||
|
structured_output_request=StructuredOutputRequest(
|
||||||
|
sampling_params=request.sampling_params),
|
||||||
)
|
)
|
||||||
|
|
||||||
def queued(self, timestamp: Optional[float] = None) -> None:
|
def queued(self, timestamp: Optional[float] = None) -> None:
|
||||||
@ -134,18 +142,23 @@ class Request:
|
|||||||
num_tokens = self.mm_positions[input_id]["length"]
|
num_tokens = self.mm_positions[input_id]["length"]
|
||||||
return num_tokens
|
return num_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def use_structured_output(self) -> bool:
|
||||||
|
return self.sampling_params.guided_decoding is not None
|
||||||
|
|
||||||
|
|
||||||
class RequestStatus(enum.IntEnum):
|
class RequestStatus(enum.IntEnum):
|
||||||
"""Status of a request."""
|
"""Status of a request."""
|
||||||
WAITING = 0
|
WAITING = enum.auto()
|
||||||
RUNNING = 1
|
WAITING_FOR_FSM = enum.auto()
|
||||||
PREEMPTED = 2
|
RUNNING = enum.auto()
|
||||||
# Note: anything after PREEMPTED (2) will be considered
|
PREEMPTED = enum.auto()
|
||||||
|
# Note: anything after PREEMPTED will be considered
|
||||||
# as a finished status.
|
# as a finished status.
|
||||||
FINISHED_STOPPED = 3
|
FINISHED_STOPPED = enum.auto()
|
||||||
FINISHED_LENGTH_CAPPED = 4
|
FINISHED_LENGTH_CAPPED = enum.auto()
|
||||||
FINISHED_ABORTED = 5
|
FINISHED_ABORTED = enum.auto()
|
||||||
FINISHED_IGNORED = 6
|
FINISHED_IGNORED = enum.auto()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def is_finished(status: "RequestStatus") -> bool:
|
def is_finished(status: "RequestStatus") -> bool:
|
||||||
|
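Switching the members to enum.auto() is what makes inserting WAITING_FOR_FSM safe: with literal values, a new state before the finished statuses would have forced every later value to be renumbered by hand. Because "anything after PREEMPTED is finished" relies only on integer ordering, the check reduces to a comparison; a standalone mock of the idea (this is not vLLM's class, and the body of is_finished is not shown in this hunk, so the comparison is an assumption about its implementation):

import enum


class Status(enum.IntEnum):
    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()   # new state slots in without renumbering
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    FINISHED_STOPPED = enum.auto()


def is_finished(status: Status) -> bool:
    # Assumed semantics: every member declared after PREEMPTED is terminal.
    return status > Status.PREEMPTED


assert not is_finished(Status.WAITING_FOR_FSM)
assert is_finished(Status.FINISHED_STOPPED)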
152   vllm/v1/structured_output/__init__.py   Normal file

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import copy
import multiprocessing
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
                                               StructuredOutputOptions)

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import xgrammar as xgr

    from vllm.v1.request import Request
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


class StructuredOutputManager:

    def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500):
        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.vllm_config = vllm_config

        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
            tokenizer, vocab_size=self.vocab_size)
        self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)

        self.max_cache_size = max_cache_size
        self.request_key_to_grammar: OrderedDict[StructuredOutputKey,
                                                 Grammar] = OrderedDict()

        # The default max_workers if not specified is the number of CPUs * 5,
        # which is way too high since these tasks are CPU-bound, not I/O bound.
        # We also know we would never dominate CPU usage with just grammar
        # compilation, so we set it to half the number of CPUs.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self._grammar_bitmask = xgr.allocate_token_bitmask(
            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)

    def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
        # We need to pop and re-insert the grammar here for LRU cache
        # of request_key_to_grammar
        if key in self.request_key_to_grammar:
            # Move accessed item to the end (most recently used)
            value = self.request_key_to_grammar.pop(key)
            if value is not None:
                self.request_key_to_grammar[key] = value
            return value
        return None

    def populate_cache(self, request: Request) -> None:
        if request.structured_output_request is None:
            return

        grammar = self.request_key_to_grammar.get(
            request.structured_output_request.structured_output_key)
        if grammar:
            request.structured_output_request.grammar = copy.copy(grammar)
            return
        request.structured_output_request.grammar = self.cache(request)

    def cache(self, request: Request):
        return self.executor.submit(self._executor_loop, request)

    def _executor_loop(self, request: Request) -> Grammar:
        # NOTE: The structured_output_request should never be
        # None in this case, but mypy can't infer this
        # correctly, so we need to ignore the error here.
        key = request.structured_output_request.structured_output_key  # type: ignore[union-attr]
        grammar = self.request_key_to_grammar.get(key)
        if grammar is not None:
            return copy.copy(grammar)
        grammar = self.initialize_grammar(key)
        # If cache is full, remove the least recently used item
        if len(self.request_key_to_grammar) >= self.max_cache_size:
            self.request_key_to_grammar.popitem(last=False)
        self.request_key_to_grammar[key] = grammar
        return copy.copy(grammar)

    def initialize_grammar(self, key: StructuredOutputKey) -> Grammar:
        # Note that the request was validated in the engine core client,
        # so at this point we know it is a supported type of request.
        #
        # TODO: we still need to handle xgrammar compilation failures
        request_type, grammar_spec = key

        if request_type == StructuredOutputOptions.JSON:
            # TODO -- allow any_whitespace to be configurable
            # pending merge of https://github.com/vllm-project/vllm/pull/12744
            ctx = self.compiler.compile_json_schema(grammar_spec,
                                                    any_whitespace=False)
        elif request_type == StructuredOutputOptions.JSON_OBJECT:
            ctx = self.compiler.compile_builtin_json_grammar()
        elif request_type == StructuredOutputOptions.GRAMMAR:
            ctx = self.compiler.compile_grammar(grammar_spec)
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError(
                f"grammar is not of valid supported types. ({request_type!s})")

        return Grammar(
            matcher=xgr.GrammarMatcher(ctx),
            vocab_size=self.vocab_size,
            ctx=ctx,
        )

    def grammar_bitmask(
        self,
        requests: dict[str, Request],
        structured_output_request_ids: dict[str, int],
        batch_len: int,
    ) -> Optional[npt.NDArray[np.int32]]:
        # Prepare the structured output bitmask for this batch.
        if not structured_output_request_ids:
            return None

        # Fill the bitmask using the index of each request equal to its
        # position in the batch. Resize the bitmask down to the size of
        # the batch.
        bitmask_tensor = self._grammar_bitmask
        for req_id, batch_index in structured_output_request_ids.items():
            request = requests[req_id].structured_output_request
            assert request is not None and request.grammar is not None
            if not request.grammar.matcher.is_terminated():
                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
        if batch_len < self._grammar_bitmask.shape[0]:
            bitmask_tensor = self._grammar_bitmask[:batch_len]

        # After finishing with the xgrammar operations, we convert to
        # np.ndarray, because that is much more efficient for serialization
        # and deserialization when sending this to the GPU workers.
        return bitmask_tensor.numpy()
77    vllm/v1/structured_output/grammar.py   Normal file

@@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING

import torch

from vllm.logger import init_logger
from vllm.utils import LazyLoader

if TYPE_CHECKING:
    import xgrammar as xgr
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


class StructuredOutputOptions(enum.Enum):
    JSON = enum.auto()
    JSON_OBJECT = enum.auto()
    REGEX = enum.auto()
    GRAMMAR = enum.auto()
    CHOICE = enum.auto()


StructuredOutputKey = tuple[StructuredOutputOptions, str]


@dataclass
class Grammar:
    # NOTE: This would be a generic-enough class for
    # supporting different backends, in the future.
    # For now, just xgrammar.
    #
    # TODO: support max_rollback_tokens
    # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
    # for jump-forward decoding

    vocab_size: int
    matcher: xgr.GrammarMatcher = field(hash=False)
    ctx: xgr.CompiledGrammar = field(hash=False)
    num_processed_tokens: int = field(default_factory=lambda: 0,
                                      repr=False,
                                      hash=False,
                                      init=False)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the FSM.

        Returns True if the FSM was advanced successfully.
        Returns False if the FSM failed to advance.
        """
        for token in tokens:
            if not self.matcher.accept_token(token):
                logger.error(
                    "Failed to advance FSM for request %s "
                    "for tokens %s. Please file an issue.", request_id, token)
                return False
            self.num_processed_tokens += 1
        return True

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
        return self.matcher.fill_next_token_bitmask(bitmask, idx)

    def reset(self):
        self.num_processed_tokens = 0
        self.matcher.reset()

    def __copy__(self):
        return Grammar(
            matcher=xgr.GrammarMatcher(self.ctx),
            vocab_size=self.vocab_size,
            ctx=self.ctx,
        )
71    vllm/v1/structured_output/request.py   Normal file

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import dataclasses
import functools
import json
from concurrent.futures import Future
from concurrent.futures._base import TimeoutError
from typing import Optional, Union, cast

from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
                                               StructuredOutputOptions)


@dataclasses.dataclass
class StructuredOutputRequest:

    sampling_params: SamplingParams
    _grammar: Optional[Union[Future[Grammar], Grammar]] = None

    def _check_grammar_completion(self) -> bool:
        # NOTE: We have to lazy import to gate circular imports
        from vllm.v1.request import RequestStatus

        if isinstance(self._grammar, Future):
            try:
                # We will check whether the future is ready within 100 us
                self._grammar = self._grammar.result(timeout=0.0001)
                self.status = RequestStatus.WAITING
            except TimeoutError:
                return False
        return True

    @property
    def is_grammar_ready(self) -> bool:
        return self._check_grammar_completion()

    @property
    def grammar(self) -> Optional[Grammar]:
        completed = self._check_grammar_completion()
        return cast(Optional[Grammar], self._grammar) if completed else None

    @grammar.setter
    def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None:
        self._grammar = grammar

    @functools.cached_property
    def structured_output_key(self) -> StructuredOutputKey:
        params = self.sampling_params.guided_decoding
        assert params is not None, "params can't be None."
        if params.json is not None:
            if not isinstance(params.json, str):
                json_str = json.dumps(params.json)
            else:
                json_str = params.json
            return (StructuredOutputOptions.JSON, json_str)
        elif params.json_object:
            return (StructuredOutputOptions.JSON_OBJECT, "")
        elif params.regex is not None:
            return (StructuredOutputOptions.REGEX, params.regex)
        elif params.choice is not None:
            if not isinstance(params.choice, str):
                json_str = json.dumps(params.choice)
            else:
                json_str = params.choice
            return (StructuredOutputOptions.CHOICE, json_str)
        elif params.grammar is not None:
            return (StructuredOutputOptions.GRAMMAR, params.grammar)
        else:
            raise ValueError("No valid structured output parameter found")
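The grammar property never blocks the scheduler for long: it polls the Future with a roughly 100 microsecond timeout and reports None until compilation lands, which is exactly what lets the scheduler park a request in WAITING_FOR_FSM. The same polling pattern in isolation, with a placeholder work function standing in for the real compiler:

import time
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from typing import Optional


def slow_compile() -> str:  # stand-in for grammar compilation
    time.sleep(0.05)
    return "compiled-grammar"


def poll(f: Future) -> Optional[str]:
    try:
        return f.result(timeout=0.0001)  # ~100 us, mirroring the request class
    except TimeoutError:
        return None


executor = ThreadPoolExecutor(max_workers=1)
fut: Future = executor.submit(slow_compile)

print(poll(fut))   # most likely None: still compiling
time.sleep(0.1)
print(poll(fut))   # 'compiled-grammar'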
295
vllm/v1/structured_output/utils.py
Normal file
295
vllm/v1/structured_output/utils.py
Normal file
@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING, Any

from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader

if TYPE_CHECKING:
    import xgrammar as xgr
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")


def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
    """Check if JSON schema contains features unsupported by xgrammar."""

    def check_object(obj: dict[str, Any]) -> bool:
        if not isinstance(obj, dict):
            return False

        # Check for pattern restrictions
        if "pattern" in obj:
            return True

        # Check for enum restrictions
        if "enum" in obj:
            return True

        # Check for numeric ranges
        if obj.get("type") in ("integer", "number") and any(
                key in obj
                for key in ("minimum", "maximum", "exclusiveMinimum",
                            "exclusiveMaximum", "multipleOf")):
            return True

        # Check for array unsupported keywords
        if obj.get("type") == "array" and any(
                key in obj
                for key in ("uniqueItems", "contains", "minContains",
                            "maxContains", "minItems", "maxItems")):
            return True

        # Unsupported keywords for strings
        if obj.get("type") == "string" and any(
                key in obj for key in ("minLength", "maxLength", "format")):
            return True

        # Unsupported keywords for objects
        if obj.get("type") == "object" and any(
                key in obj for key in ("minProperties", "maxProperties",
                                       "propertyNames", "patternProperties")):
            return True

        # Recursively check all nested objects and arrays
        for value in obj.values():
            if isinstance(value, dict):
                if check_object(value):
                    return True
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, dict) and check_object(item):
                        return True

        return False

    return check_object(schema)

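As a usage sketch for the check above (the import path is the module added in this PR; the schemas are illustrative), a plain object schema passes while a schema using pattern or numeric bounds is flagged as unsupported:

from vllm.v1.structured_output.utils import (
    has_xgrammar_unsupported_json_features)

plain_schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}
restricted_schema = {
    "type": "object",
    "properties": {
        "age": {"type": "integer", "minimum": 0, "maximum": 120},
        "zip": {"type": "string", "pattern": r"^\d{5}$"},
    },
}

assert not has_xgrammar_unsupported_json_features(plain_schema)
assert has_xgrammar_unsupported_json_features(restricted_schema)
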
def grammar_is_likely_lark(grammar_str: str) -> bool:
    """
    Check if grammar appears to use Lark syntax.

    Args:
        grammar_str: Input grammar string

    Returns:
        bool: True if grammar appears to be in Lark format, False otherwise

    Examples:
        >>> grammar_is_likely_lark("rule: 'abc'")
        True
        >>> grammar_is_likely_lark("rule ::= 'abc'")
        False
    """
    if not grammar_str or not isinstance(grammar_str, str):
        return False

    for line in grammar_str.split('\n'):
        # Remove both comment styles
        line = re.sub(r'(#|//).*$', '', line).strip()
        if not line:
            continue

        # Look for EBNF rule definition
        if '::=' in line:
            return False

    return True


def convert_lark_to_ebnf(grammar_str: str) -> str:
    """
    Convert a Lark grammar string to EBNF format.

    EBNF reference:
    https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
    Lark grammar reference:
    https://lark-parser.readthedocs.io/en/latest/grammar.html

    Args:
        grammar_str: Input grammar in Lark format

    Returns:
        str: Converted grammar in EBNF format

    Examples:
        >>> print(convert_lark_to_ebnf("rule: 'hello'"))
        root ::= rule
        rule ::= "hello"
    """
    if not isinstance(grammar_str, str):
        raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
    if not grammar_str.strip():
        raise ValueError("Grammar string cannot be empty")

    defined_rules = set()
    referenced_rules = set()
    output_lines = []

    def clean_line(line: str) -> str:
        """Remove comments and whitespace from line."""
        return re.sub(r'(#|//).*$', '', line).strip()

    def check_quotes(text: str, rule_name: str, line_num: int) -> None:
        """Validate quote matching in text."""
        if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
            raise ValueError(
                f"Mismatched quotes in {rule_name} on line {line_num}")

    def extract_references(text: str) -> set:
        """Extract rule references from text."""
        # Remove quoted strings and special characters
        text = re.sub(r'"[^"]*"', '', text)
        text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
        return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))

    # First pass: Find root rule and validate rule definitions
    lines = [clean_line(line) for line in grammar_str.split('\n')]
    first_rule = None

    for line_num, line in enumerate(lines, 1):
        if not line or line.startswith('|'):
            continue

        if ':' in line:
            try:
                name = line.split(':', 1)[0].strip().strip('?')
                defined_rules.add(name)
                if first_rule is None:
                    first_rule = name
                if name == 'start':
                    first_rule = 'start'
            except IndexError as e:
                raise ValueError(f"Invalid rule format on line {line_num}. "
                                 "Expected 'rule_name: definition'") from e

    if not defined_rules:
        raise ValueError("No valid rules found in grammar")

    # Add root rule
    output_lines.append(f"root ::= {first_rule}")

    # Second pass: Process rule definitions and alternatives
    current_rule = None
    current_definition = []

    for line_num, line in enumerate(lines, 1):
        if not line:
            continue

        try:
            if ':' in line and not line.startswith('|'):
                # Save previous rule if exists
                if current_rule:
                    output_lines.append(
                        f"{current_rule} ::= {' | '.join(current_definition)}")

                # Process new rule
                name, definition = line.split(':', 1)
                current_rule = name.strip().strip('?')

                check_quotes(definition, f"rule '{current_rule}'", line_num)
                definition = re.sub(r"'([^']*)'", r'"\1"', definition)
                referenced_rules.update(extract_references(definition))
                current_definition = [definition.strip()]

            elif line.startswith('|'):
                if not current_rule:
                    raise ValueError(f"Alternative '|' on line {line_num} "
                                     "without a preceding rule definition")

                alt_def = line[1:].strip()
                check_quotes(alt_def, f"alternative for rule '{current_rule}'",
                             line_num)
                alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
                referenced_rules.update(extract_references(alt_def))
                current_definition.append(alt_def)

        except ValueError as e:
            raise ValueError(f"Error on line {line_num}: {str(e)}") from e

    # Add final rule if exists
    if current_rule:
        output_lines.append(
            f"{current_rule} ::= {' | '.join(current_definition)}")

    # Validate all rules are defined
    undefined_rules = referenced_rules - defined_rules - {'root'}
    if undefined_rules:
        raise ValueError("Referenced rules are not defined: "
                         f"{', '.join(sorted(undefined_rules))}")

    return '\n'.join(output_lines)


def choice_as_grammar(choice: list[str]) -> str:

    def escape_ebnf_string(s: str) -> str:
        """Escape special characters in an EBNF string."""
        # Escape double quotes and backslashes
        return re.sub(r'(["\\])', r'\\\1', s)

    escaped_choices = (escape_ebnf_string(c) for c in choice)
    grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
    return grammar

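For reference, choice_as_grammar collapses a choice list into a single EBNF root rule. The expected outputs below follow directly from the code above (illustrative usage, assuming the module import path from this PR):

from vllm.v1.structured_output.utils import choice_as_grammar

# Each choice becomes a quoted alternative of a single root rule.
print(choice_as_grammar(["yes", "no", "maybe"]))
# root ::= "yes" | "no" | "maybe"

# Double quotes and backslashes inside a choice are escaped first.
print(choice_as_grammar(['say "hi"']))
# root ::= "say \"hi\""
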
def validate_structured_output_request(
        sampling_params: SamplingParams) -> None:
    """Validate that the request is supported by structured output.

    Raises ValueError if the request is not supported.
    """
    if sampling_params.guided_decoding is None:
        return

    gd_params = sampling_params.guided_decoding

    if gd_params.regex:
        raise ValueError("Regex structured output is not supported.")

    if gd_params.choice:
        choice_grammar = choice_as_grammar(gd_params.choice)
        try:
            xgr.Grammar.from_ebnf(choice_grammar)
        except Exception as err:
            raise ValueError("Failed to transform choices into a grammar: "
                             f"{err}") from err
        gd_params.choice = None
        gd_params.grammar = choice_grammar
        return

    if gd_params.json:
        if isinstance(gd_params.json, str):
            try:
                schema = json.loads(gd_params.json)
            except json.JSONDecodeError as e:
                raise ValueError("Invalid JSON grammar specification.") from e
        else:
            schema = gd_params.json

        if has_xgrammar_unsupported_json_features(schema):
            raise ValueError("The provided JSON schema contains features not "
                             "supported by xgrammar.")
        return

    if gd_params.grammar:
        if grammar_is_likely_lark(gd_params.grammar):
            # xgrammar supports EBNF grammars only
            try:
                gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
            except ValueError as e:
                raise ValueError(
                    "Failed to convert the grammar from Lark to EBNF.") from e

        # Test parsing EBNF grammar, possibly already converted from Lark
        try:
            # parse the grammar, but we aren't compiling it.
            xgr.Grammar.from_ebnf(gd_params.grammar)
        except Exception as e:
            raise ValueError("Invalid grammar specification.") from e

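A usage sketch of the validator follows; it assumes the import paths shown in this PR and an environment with xgrammar installed. As of this commit, a choice constraint is rewritten into an EBNF grammar in place, while regex constraints are rejected.

from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.structured_output.utils import validate_structured_output_request

# A choice constraint is converted to a grammar in place.
params = SamplingParams(
    guided_decoding=GuidedDecodingParams(choice=["yes", "no"]))
validate_structured_output_request(params)
assert params.guided_decoding.choice is None
assert params.guided_decoding.grammar == 'root ::= "yes" | "no"'

# Regex constraints are rejected by the V1 structured output backend.
try:
    validate_structured_output_request(
        SamplingParams(guided_decoding=GuidedDecodingParams(regex=r"\d+")))
except ValueError as e:
    print(e)  # Regex structured output is not supported.
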
vllm/v1/worker/gpu_model_runner.py

@ -25,7 +25,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        LayerBlockType, LazyLoader, cdiv,
                        is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient

@ -40,7 +41,11 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

if TYPE_CHECKING:
    import xgrammar as xgr

    from vllm.v1.core.scheduler_output import SchedulerOutput
else:
    xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)

@ -860,6 +865,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
    def get_model(self) -> nn.Module:
        return self.model

    def apply_grammar_bitmask(
        self,
        scheduler_output: "SchedulerOutput",
        logits: torch.Tensor,
    ):
        # Serialization of np.ndarray is much more efficient than a tensor,
        # so we receive it in that format.
        grammar_bitmask = scheduler_output.grammar_bitmask
        if grammar_bitmask is None:
            return

        # We receive the structured output bitmask from the scheduler, but the
        # indices of the requests in the batch may not match the indices of
        # the bitmask since the scheduler doesn't know how the gpu runner is
        # ordering the requests in the batch. We need to sort the bitmask to
        # match the order of the requests used here.
        struct_out_req_batch_indices: dict[str, int] = {}
        indices_match = True
        for req_id in self.input_batch.req_ids:
            mask_index = scheduler_output.structured_output_request_ids.get(
                req_id)
            if mask_index is None:
                # not a structured output request
                continue
            batch_index = self.input_batch.req_id_to_index[req_id]
            if batch_index != mask_index:
                indices_match = False
            struct_out_req_batch_indices[req_id] = batch_index

        if not indices_match:
            # Sort the bitmask to match the order of the requests
            sorted_bitmask = np.zeros_like(grammar_bitmask)
            for req_id, batch_index in struct_out_req_batch_indices.items():
                orig_index = scheduler_output.structured_output_request_ids[
                    req_id]
                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
            grammar_bitmask = sorted_bitmask

        grammar_bitmask = torch.from_numpy(grammar_bitmask)

        # TODO: compatibility with spec decode
        xgr.apply_token_bitmask_inplace(
            logits,
            grammar_bitmask.to(self.device, non_blocking=True),
            indices=list(struct_out_req_batch_indices.values()),
        )

    @torch.inference_mode()
    def execute_model(
        self,
@ -945,6 +997,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        sample_hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(sample_hidden_states, None)

        # Apply structured output bitmasks if present
        if scheduler_output.grammar_bitmask is not None:
            self.apply_grammar_bitmask(scheduler_output, logits)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        if not self.use_spec_decode:
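To make the reordering above concrete, here is a small self-contained sketch with hypothetical request IDs and a toy boolean mask standing in for xgrammar's packed bitmask: rows produced in scheduler order are scattered into the runner's batch order, then disallowed token logits are masked out. The real path delegates the masking step to xgr.apply_token_bitmask_inplace.

# Illustrative only: reorder a per-request token mask from scheduler order to
# batch order, then apply it to logits. Names and shapes here are made up.
import numpy as np
import torch

vocab_size = 8
batch_req_ids = ["req-b", "req-a", "req-c"]          # runner's batch order
req_id_to_batch_index = {r: i for i, r in enumerate(batch_req_ids)}
structured_output_request_ids = {"req-a": 0, "req-b": 1}  # scheduler order

# Scheduler-ordered mask: row i belongs to the request with mask index i.
scheduler_mask = np.array([
    [1, 1, 0, 0, 0, 0, 0, 0],   # req-a: only tokens 0-1 allowed
    [0, 0, 1, 1, 0, 0, 0, 0],   # req-b: only tokens 2-3 allowed
], dtype=bool)

# Scatter rows into a batch-sized mask; unconstrained rows allow everything.
batch_mask = np.ones((len(batch_req_ids), vocab_size), dtype=bool)
for req_id, mask_index in structured_output_request_ids.items():
    batch_mask[req_id_to_batch_index[req_id]] = scheduler_mask[mask_index]

logits = torch.zeros(len(batch_req_ids), vocab_size)
logits[~torch.from_numpy(batch_mask)] = float("-inf")

# req-b sits at batch index 0, so its row only keeps tokens 2-3.
assert torch.isinf(logits[0, 0]) and not torch.isinf(logits[0, 2])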