[V1][Core] Support for Structured Outputs (#12388)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Authored by Aaron Pham on 2025-03-07 10:19:11 -05:00; committed by GitHub
parent 1e3598edeb
commit 80e9afb5bc
26 changed files with 1528 additions and 715 deletions
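For context, the snippet below is a minimal usage sketch of the feature this commit adds, assembled from the new V1 tests included in this diff (model name, schema, and prompt are illustrative):

# Minimal sketch of V1 structured outputs, mirroring the new tests below.
import json
import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Illustrative schema; any xgrammar-compatible JSON schema works.
schema = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    "required": ["name", "age"],
}

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)
sampling_params = SamplingParams(
    temperature=1.0,
    max_tokens=200,
    guided_decoding=GuidedDecodingParams(json=schema, backend="xgrammar"),
)
outputs = llm.generate(
    prompts=f"Give an example JSON that fits this schema: {json.dumps(schema)}",
    sampling_params=sampling_params,
)
print(json.loads(outputs[0].outputs[0].text))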


@ -204,6 +204,7 @@ steps:
- VLLM_USE_V1=1 pytest -v -s v1/engine
- VLLM_USE_V1=1 pytest -v -s v1/sample
- VLLM_USE_V1=1 pytest -v -s v1/worker
- VLLM_USE_V1=1 pytest -v -s v1/structured_output
- VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
- VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
# TODO: accuracy does not match, whether setting

.gitignore vendored

@ -197,7 +197,7 @@ _build/
hip_compat.h
# Benchmark dataset
benchmarks/*.json
benchmarks/**/*.json
# Linting
actionlint


@ -1,507 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
"""Benchmark guided decoding throughput."""
import argparse
import dataclasses
import json
import os
import random
import time
import datasets
import pandas as pd
import uvloop
from transformers import AutoTokenizer, PreTrainedTokenizerBase
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.sampling_params import GuidedDecodingParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@dataclasses.dataclass
class SampleRequest:
"""A class representing a single inference request for benchmarking.
Attributes:
    prompt: The input text prompt for the model.
    prompt_len: The length of the prompt in tokens.
    expected_output_len: The expected length of the output in tokens.
    schema: The JSON schema, regex, grammar, or choice list used to
        constrain the output.
    structure_type: The kind of structured output the schema represents.
    completion: The reference completion for this request, if available.
"""
prompt: str
prompt_len: int
expected_output_len: int
schema: dict
structure_type: str = 'json'
completion: str = None
def run_vllm(requests: list[SampleRequest],
engine_args: EngineArgs,
n: int,
guided_decoding_rate: float = 1.0,
warmup: bool = False) -> tuple[float, list[dict]]:
from vllm import LLM, SamplingParams
llm = LLM(**vars(engine_args))
assert all(
llm.llm_engine.model_config.max_model_len >= (
request.prompt_len + request.expected_output_len)
for request in requests), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
# Add the requests to the engine.
prompts: list[str] = []
sampling_params: list[SamplingParams] = []
# create a list containing random selected true or false
guided_decoding_req_idx = random.sample(
range(len(requests)), int(len(requests) * guided_decoding_rate))
if warmup:
print(">>>>> Running warmup prompts (first 5 requests)")
# Use the first 5 requests to warm up the FSM.
# Warmup is skipped when using the xgrammar_bench dataset.
warmup_requests = requests[:5]
for i, request in enumerate(warmup_requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(json=request.schema)
if guided_decoding_rate > 0 else None,
))
llm.generate(prompts, sampling_params, use_tqdm=False)
print(">>>>> Benchmark started...")
prompts = []
sampling_params = []
for i, request in enumerate(requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(
**{request.structure_type: request.schema})
if i in guided_decoding_req_idx else None,
))
start = time.perf_counter()
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
ret = []
for output, request in zip(outputs, requests):
generated_text = output.outputs[0].text
ret.append({
"generated": generated_text,
"expected": request.completion
})
end = time.perf_counter()
return end - start, ret
async def run_vllm_async(
requests: list[SampleRequest],
engine_args: AsyncEngineArgs,
n: int,
guided_decoding_rate: float = 1.0,
warmup: bool = False,
disable_frontend_multiprocessing: bool = False
) -> tuple[float, list[dict], tuple[pd.Series, pd.Series]]:
from vllm import SamplingParams
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
assert all(
llm.model_config.max_model_len >= (request.prompt_len +
request.expected_output_len)
for request in requests), (
"Please ensure that max_model_len is greater than the sum of"
" prompt_len and expected_output_len for all requests.")
# Add the requests to the engine.
prompts: list[str] = []
sampling_params: list[SamplingParams] = []
guided_decoding_req_idx = random.sample(
range(len(requests)), int(len(requests) * guided_decoding_rate))
if warmup:
print(">>>>> Running warmup prompts (first 5 requests)")
# Use the first 5 requests to warm up the FSM.
# Warmup is skipped when using the xgrammar_bench dataset.
warmup_requests = requests[:5]
for i, request in enumerate(warmup_requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(
json=request.schema)
if guided_decoding_rate > 0 else None,
))
generators = []
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
all_gens = merge_async_iterators(*generators)
async for i, res in all_gens:
pass
print(">>>>> Benchmark started...")
prompts = []
sampling_params = []
for i, request in enumerate(requests):
prompts.append(request.prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=1.0,
top_p=1.0,
ignore_eos=True,
max_tokens=request.expected_output_len,
guided_decoding=GuidedDecodingParams(json=request.schema)
if i in guided_decoding_req_idx else None,
))
generators = []
start_time = []
latencies = []
start = time.perf_counter()
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
generator = llm.generate(prompt, sp, request_id=f"test{i}")
generators.append(generator)
start_time.append(time.perf_counter())
latencies.append([])
all_gens = merge_async_iterators(*generators)
generated_texts = [''] * len(requests)
async for i, res in all_gens:
generated_texts[i] = res.outputs[0].text
lat = time.perf_counter() - start_time[i]
latencies[i].append(lat)
ret = [{
'generated': gt,
'expected': req.completion
} for gt, req in zip(generated_texts, requests)]
end = time.perf_counter()
first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
for lat in latencies])
return end - start, ret, (first_latency, next_latency)
def sample_requests(tokenizer: PreTrainedTokenizerBase,
args: argparse.Namespace) -> list[SampleRequest]:
if args.dataset == 'json':
if args.json_schema_path is None:
dir_path = os.path.dirname(os.path.realpath(__file__))
args.json_schema_path = os.path.join(dir_path,
"structured_schemas",
"structured_schema_1.json")
with open(args.json_schema_path) as f:
schema = json.load(f)
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "grammar":
schema = """
?start: select_statement
?select_statement: "SELECT " column_list " FROM " table_name
?column_list: column_name ("," column_name)*
?table_name: identifier
?column_name: identifier
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
"""
prompt = "Generate an SQL query to show the 'username' \
and 'email' from the 'users' table."
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "regex":
regex = r"\w+@\w+\.com\n"
args.regex = regex
prompt = "Generate an email address for Alan Turing, \
who works in Enigma. End in .com and new line. \
Example result: alan.turing@enigma.com\n"
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=regex,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "choice":
choice = ["Positive", "Negative"]
args.choice = choice
prompt = "Classify this sentiment: vLLM is wonderful!"
input_len = len(tokenizer(prompt).input_ids)
print(f"Input length of the prompt: {input_len} tokens")
requests = [
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=choice,
structure_type=args.structure_type)
for _ in range(args.num_prompts)
]
elif args.dataset == "xgrammar_bench":
args.warmup = False
requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train")
print(f"dataset has {len(dataset)} entries")
len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts):
idx = data_point_idx
while idx >= len_dataset:
idx -= len_dataset
schema = dataset["schema"][idx]
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
tokenize=False)
input_len = len(tokenizer(prompt).input_ids)
completion = dataset["completion"][idx]
requests.append(
SampleRequest(prompt=prompt,
prompt_len=input_len,
expected_output_len=args.output_len,
schema=schema,
completion=completion))
return requests
def evaluate(ret, args):
def _eval_correctness_json(expected, actual):
# Extract the JSON object from the generated text using a regex.
import re
actual = actual.replace('\n', '').replace(' ', '').strip()
try:
actual = re.search(r'\{.*\}', actual).group()
actual = json.loads(actual)
except Exception:
return False
return True
def _eval_correctness_choice(expected, actual):
return actual in args.choice
def _eval_correctness_regex(expected, actual):
import re
return re.match(args.regex, actual) is not None
def _eval_correctness(expected, actual):
if args.structure_type == 'json':
return _eval_correctness_json(expected, actual)
elif args.structure_type == 'regex':
return _eval_correctness_regex(expected, actual)
elif args.structure_type == 'choice':
return _eval_correctness_choice(expected, actual)
else:
return None
scores = []
for res in ret:
score = _eval_correctness(res['expected'], res['generated'])
res['correctness'] = score
scores.append(score)
not_none_scores = [score for score in scores if score is not None]
return (sum(not_none_scores) / len(not_none_scores) *
100) if len(not_none_scores) > 0 else None
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# The async engine is not used for the 'grammar', 'regex' and 'choice' datasets.
if args.dataset == 'grammar':
args.structure_type = 'grammar'
args.async_engine = False
elif args.dataset == 'regex':
args.structure_type = 'regex'
args.async_engine = False
elif args.dataset == 'choice':
args.structure_type = 'choice'
args.async_engine = False
else:
args.structure_type = 'json'
if args.no_guided_decoding:
args.guided_decoding_ratio = 0
if args.save_results:
result_file_name = f'{args.guided_decoding_ratio}guided'
result_file_name += f"_{args.model.split('/')[-1]}"
result_file_name += f"_{args.dataset}"
result_file_name += f"_{args.num_prompts}"
result_file_name += f"_out{args.output_len}"
result_file_name += f"_async{args.async_engine}"
result_file_name += f"_warmup{args.warmup}"
result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}"
result_file_name += ".txt"
else:
result_file_name = None
# Synthesize a prompt with the given input length.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
requests = sample_requests(tokenizer, args)
if args.async_engine:
engine_args = AsyncEngineArgs.from_cli_args(args)
elapsed_time, ret, (first_latency, next_latency) = uvloop.run(
run_vllm_async(requests, engine_args, args.n,
args.guided_decoding_ratio, args.warmup,
args.disable_frontend_multiprocessing))
else:
engine_args = EngineArgs.from_cli_args(args)
elapsed_time, ret = run_vllm(requests, engine_args, args.n,
args.guided_decoding_ratio, args.warmup)
first_latency, next_latency = None, None
score = evaluate(ret, args)
total_num_tokens = sum(request.prompt_len + request.expected_output_len
for request in requests)
total_output_tokens = sum(request.expected_output_len
for request in requests)
if first_latency is not None:
latency_breakdown = "\nFirst token latency(msecs):\n"
latency_breakdown += f"{first_latency.describe()}"
latency_breakdown += "\nNext token latency(msecs):\n"
latency_breakdown += f"{next_latency.describe()}"
print(
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
f"{total_output_tokens / elapsed_time:.2f} output tokens/s",
f"Correct rate is {score} %",
f"{latency_breakdown if first_latency is not None else ''}")
# Output JSON results if specified
if args.output_json or result_file_name:
results = {
"elapsed_time": elapsed_time,
"num_requests": len(requests),
"total_num_tokens": total_num_tokens,
"total_output_tokens": total_output_tokens,
"requests_per_second": len(requests) / elapsed_time,
"tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}",
"output_tokens_per_second":
f"{total_output_tokens / elapsed_time:.2f}",
"correct_rate(%)": score
}
results = {"outputs": ret, **results}
if first_latency is not None:
results["first_token_latency(msecs)"] = first_latency.describe(
).to_dict()
results["next_token_latency(msecs)"] = next_latency.describe(
).to_dict()
if args.output_json:
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)
elif result_file_name:
with open(result_file_name, "w") as f:
json.dump(results, f, indent=4)
if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument("--output-len",
type=int,
default=512,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument(
"--dataset",
default='json',
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
parser.add_argument("--json_schema_path",
type=str,
default=None,
help="Path to json schema.")
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=10,
help="Number of prompts to process.")
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument("--async-engine",
action='store_true',
default=False,
help="Use vLLM async engine rather than LLM class.")
parser.add_argument("--no-guided-decoding",
action='store_true',
default=False,
help="Disable guided decoding.")
parser.add_argument("--guided-decoding-ratio",
type=float,
default=1.0,
help="Ratio of Guided Decoding requests")
parser.add_argument("--disable-frontend-multiprocessing",
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser.add_argument("--warmup",
action="store_true",
default=False,
help="Run warmup prompts before benchmark.")
parser.add_argument("--save-results",
action="store_true",
default=False,
help="save output results.")
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
main(args)


@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
r"""Benchmark online serving throughput with guided decoding.
r"""Benchmark online serving throughput with structured outputs.
On the server side, run one of the following commands:
(vLLM OpenAI API server)
@ -9,12 +9,12 @@ On the server side, run one of the following commands:
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run:
python benchmarks/benchmark_serving_guided.py \
python benchmarks/benchmark_serving_structured_output.py \
--backend <backend> \
--model <your_model> \
--dataset json \
--guided-decoding-ratio 1.0 \
--guided-decoding-backend xgrammar \
--structured-output-ratio 1.0 \
--structured-output-backend xgrammar \
--request-rate 10 \
--num-prompts 1000
@ -52,6 +52,9 @@ try:
except ImportError:
from argparse import ArgumentParser as FlexibleArgumentParser
from vllm.v1.structured_output.utils import (
has_xgrammar_unsupported_json_features)
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
@ -191,7 +194,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
requests: list[SampleRequest] = []
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
split="train")
print(f"dataset has {len(dataset)} entries")
full_dataset_len = len(dataset)
def _filter_func(item):
import json
schema = json.loads(item["schema"])
return not has_xgrammar_unsupported_json_features(schema)
dataset = dataset.filter(_filter_func)
num_filtered_out = full_dataset_len - len(dataset)
print(f"dataset has {len(dataset)} entries after filtering "
f"out {num_filtered_out} entries with unsupported features")
len_dataset = len(dataset)
for data_point_idx in range(args.num_prompts):
idx = data_point_idx
@ -378,8 +391,8 @@ async def benchmark(
selected_percentiles: list[str],
ignore_eos: bool,
max_concurrency: Optional[int],
guided_decoding_ratio: float,
guided_decoding_backend: str,
structured_output_ratio: float,
structured_output_backend: str,
goodput_config_dict: Optional[dict[str, float]] = None,
):
if backend in ASYNC_REQUEST_FUNCS:
@ -391,16 +404,18 @@ async def benchmark(
extra_body = {}
# Add the schema to the extra_body
extra_body[request.structure_type] = request.schema
# Add the specific guided_decoding_backend
extra_body["guided_decoding_backend"] = guided_decoding_backend
# Add the specific structured_output_backend
extra_body["guided_decoding_backend"] = structured_output_backend
return extra_body
print("Starting initial single prompt test run...")
guided_decoding_req_idx = random.sample(
structured_output_req_idx = random.sample(
range(len(input_requests)),
int(len(input_requests) * guided_decoding_ratio))
int(len(input_requests) * structured_output_ratio))
test_request = input_requests[0]
test_req_extra_body = (prepare_extra_body(test_request)
if 0 in structured_output_req_idx else None)
test_input = RequestFuncInput(
model=model_id,
prompt=test_request.prompt,
@ -408,7 +423,7 @@ async def benchmark(
prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len,
ignore_eos=ignore_eos,
extra_body=prepare_extra_body(test_request),
extra_body=test_req_extra_body,
)
test_output = await request_func(request_func_input=test_input)
if not test_output.success:
@ -427,7 +442,7 @@ async def benchmark(
prompt_len=test_request.prompt_len,
output_len=test_request.expected_output_len,
ignore_eos=ignore_eos,
extra_body=prepare_extra_body(test_request),
extra_body=test_req_extra_body,
)
profile_output = await request_func(request_func_input=profile_input)
if profile_output.success:
@ -465,7 +480,7 @@ async def benchmark(
async for i, request in get_request(input_requests, request_rate,
burstiness):
extra_body = prepare_extra_body(
request) if i in guided_decoding_req_idx else None
request) if i in structured_output_req_idx else None
request_func_input = RequestFuncInput(
model=model_id,
prompt=request.prompt,
@ -708,10 +723,10 @@ def main(args: argparse.Namespace):
else:
args.structure_type = 'guided_json'
if args.no_guided_decoding:
args.guided_decoding_ratio = 0
if args.no_structured_output:
args.structured_output_ratio = 0
if args.save_results:
result_file_name = f'{args.guided_decoding_ratio}guided'
result_file_name = f'{args.structured_output_ratio}guided'
result_file_name += f"_{backend}"
result_file_name += f"_{args.request_rate}qps"
result_file_name += f"_{args.model.split('/')[-1]}"
@ -744,8 +759,8 @@ def main(args: argparse.Namespace):
],
ignore_eos=args.ignore_eos,
max_concurrency=args.max_concurrency,
guided_decoding_ratio=args.guided_decoding_ratio,
guided_decoding_backend=args.guided_decoding_backend,
structured_output_ratio=args.structured_output_ratio,
structured_output_backend=args.structured_output_backend,
goodput_config_dict=goodput_config_dict,
))
@ -943,19 +958,19 @@ if __name__ == "__main__":
"goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
"and the blog: https://hao-ai-lab.github.io/blogs/distserve")
parser.add_argument("--no-guided-decoding",
parser.add_argument("--no-structured-output",
action='store_true',
default=False,
help="Disable structured output (guided decoding).")
parser.add_argument("--guided-decoding-ratio",
parser.add_argument("--structured-output-ratio",
type=float,
default=1.0,
help="Ratio of Guided Decoding requests")
parser.add_argument("--guided-decoding-backend",
help="Ratio of Structured Outputs requests")
parser.add_argument("--structured-output-backend",
type=str,
choices=["outlines", "lm-format-enforcer", "xgrammar"],
default="xgrammar",
help="Backend to use for guided decoding")
help="Backend to use for structured outputs")
args = parser.parse_args()
main(args)
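For readers running this benchmark against an OpenAI-compatible vLLM server, here is a sketch of the per-request payload that prepare_extra_body above produces. The server URL, model, and schema are placeholders; for the json datasets structure_type is 'guided_json', and the request-level key remains guided_decoding_backend even though the CLI flag is now --structured-output-backend.

# Sketch only: what one structured-output request looks like on the wire.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
schema = {"type": "object", "properties": {"name": {"type": "string"}}}

completion = client.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",
    prompt=f"Give an example JSON for this schema: {schema}",
    max_tokens=256,
    extra_body={
        "guided_json": schema,                  # request.structure_type -> schema
        "guided_decoding_backend": "xgrammar",  # from --structured-output-backend
    },
)
print(completion.choices[0].text)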


@ -0,0 +1,64 @@
#!/bin/bash
# Define the model to use
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}
# Define the backend to use
BACKEND=${2:-"vllm"}
# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}
# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}
GUIDED_RATIO=${6:-0.5}
# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"
# Define QPS values to test
QPS_VALUES=(70 60 50 25 20 15 10)
# Common parameters
COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \
--dataset $DATASET \
--structured-output-backend $GUIDED_BACKEND \
--structured-output-ratio $GUIDED_RATIO \
--save-results \
--result-dir $OUTPUT_DIR"
echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------"
# Run benchmarks with different QPS values
for qps in "${QPS_VALUES[@]}"; do
echo "Running benchmark with QPS: $qps"
# Get git hash and branch for the filename
GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
# Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
--request-rate $qps \
--result-filename "$FILENAME" \
--port ${PORT:-8000}
echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------"
done
echo "All benchmarks completed!"
echo "Results saved to: $OUTPUT_DIR"


@ -1,113 +1,25 @@
Old schema (removed):
{
  "$schema": "https://json-schema.org/draft/2020-12/schema",
  "title": "User Profile",
  "type": "object",
  "properties": {
    "userId": {
      "type": "string",
      "description": "Unique identifier for the user."
    },
    "personalInfo": {
      "type": "object",
      "properties": {
        "firstName": {"type": "string", "description": "The user's first name."},
        "lastName": {"type": "string", "description": "The user's last name."},
        "age": {"type": "integer", "minimum": 0, "description": "The user's age."},
        "phoneNumbers": {
          "type": "array",
          "items": {
            "type": "object",
            "properties": {
              "type": {
                "type": "string",
                "enum": ["home", "work", "mobile"],
                "description": "Type of phone number."
              },
              "number": {
                "type": "string",
                "pattern": "^\\+?[1-9]\\d{1,14}$",
                "description": "Phone number in E.164 format."
              }
            },
            "required": ["type", "number"]
          },
          "description": "List of phone numbers associated with the user."
        }
      },
      "required": ["firstName", "lastName"]
    },
    "address": {
      "type": "object",
      "properties": {
        "street": {"type": "string", "description": "Street address."},
        "city": {"type": "string", "description": "City name."},
        "state": {"type": "string", "description": "State or province."},
        "postalCode": {
          "type": "string",
          "pattern": "^\\d{5}(-\\d{4})?$",
          "description": "Postal code."
        },
        "country": {"type": "string", "description": "Country name."}
      },
      "required": ["street", "city", "state", "postalCode", "country"]
    },
    "preferences": {
      "type": "object",
      "properties": {
        "newsletterSubscribed": {
          "type": "boolean",
          "description": "Indicates if the user is subscribed to the newsletter."
        },
        "favoriteCategories": {
          "type": "array",
          "items": {"type": "string"},
          "description": "List of user's favorite categories."
        }
      },
      "required": ["newsletterSubscribed"]
    },
    "accountStatus": {
      "type": "string",
      "enum": ["active", "inactive", "suspended"],
      "description": "Current status of the user's account."
    },
    "registrationDate": {
      "type": "string",
      "format": "date-time",
      "description": "ISO 8601 formatted date-time of user registration."
    }
  },
  "required": ["userId", "personalInfo", "address", "accountStatus", "registrationDate"]
}

New schema (added):
{
  "type": "object",
  "properties": {
    "name": {"type": "string"},
    "race": {"type": "string"},
    "class": {"type": "string"},
    "level": {"type": "integer"},
    "background": {"type": "string"},
    "alignment": {"type": "string"},
    "backstory": {"type": "string"}
  },
  "required": ["name", "race", "class", "level", "background", "alignment", "backstory"]
}


@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
EOS_TOKEN_ID = 50256
@ -36,13 +37,21 @@ def create_scheduler(
swap_space=0,
cache_dtype="auto",
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
)
cache_config.num_gpu_blocks = 10000
return Scheduler(scheduler_config,
return Scheduler(
scheduler_config,
model_config,
cache_config,
speculative_config=None,
lora_config=None,
log_stats=True)
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests(
@ -249,7 +258,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@ -299,7 +310,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@ -347,7 +360,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
@ -392,7 +407,9 @@ def test_stop_via_update_from_output():
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],


@ -29,6 +29,7 @@ def sample_regex():
r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")
# Note: Ensure this only uses attributes compatible with xgrammar
@pytest.fixture
def sample_json_schema():
return {
@ -44,9 +45,7 @@ def sample_json_schema():
"type": "array",
"items": {
"type": "string",
"maxLength": 10
},
"minItems": 3
}
},
"work_history": {
"type": "array",
@ -71,8 +70,9 @@ def sample_json_schema():
}
# A schema unsupported by xgrammar
@pytest.fixture
def sample_complex_json_schema():
def unsupported_json_schema():
return {
"type": "object",
"properties": {
@ -150,7 +150,19 @@ def sample_guided_choice():
@pytest.fixture
def sample_sql_statements():
def sample_sql_ebnf():
return """
root ::= select_statement
select_statement ::= "SELECT" column "from" table "where" condition
column ::= "col_1" | "col_2"
table ::= "table_1" | "table_2"
condition ::= column "=" number
number ::= "1" | "2"
"""
@pytest.fixture
def sample_sql_lark():
return ("""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition


@ -0,0 +1,269 @@
# SPDX-License-Identifier: Apache-2.0
import json
import jsonschema
import pytest
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_completion(monkeypatch, sample_json_schema,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=sample_json_schema,
backend=guided_decoding_backend))
outputs = llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {sample_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
output_json = json.loads(generated_text)
jsonschema.validate(instance=output_json, schema=sample_json_schema)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=100,
n=2,
guided_decoding=GuidedDecodingParams(
json_object=True,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a JSON object with curly braces for a person with "
"name and age fields for John Smith who is 31 years old."),
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
for i in range(2):
generated_text = output.outputs[i].text
print(generated_text)
assert generated_text is not None
# Parse to verify it is valid JSON
parsed_json = json.loads(generated_text)
assert isinstance(parsed_json, dict)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=1.0,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
json=unsupported_json_schema,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="The provided JSON schema contains features "
"not supported by xgrammar."):
llm.generate(prompts=[
f"Give an example JSON for an employee profile "
f"that fits this schema: {unsupported_json_schema}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_ebnf,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar=sample_sql_lark,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
assert generated_text is not None
# use Lark to parse the output, and make sure it's a valid parse tree
from lark import Lark
parser = Lark(sample_sql_lark)
parser.parse(generated_text)
# remove spaces for comparison b/c we removed them in the grammar
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
" ", "")
assert generated_text.strip() == ground_truth
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf_invalid(monkeypatch,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=1000,
guided_decoding=GuidedDecodingParams(
grammar="not a grammar",
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Failed to convert the grammar "
"from Lark to EBNF."):
llm.generate(
prompts=("Generate a sql statement that selects col_1 from "
"table_1 where it is equal to 1"),
sampling_params=sampling_params,
use_tqdm=True,
)
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
regex=sample_regex,
backend=guided_decoding_backend))
with pytest.raises(ValueError,
match="Regex guided decoding is not supported."):
llm.generate(prompts=[
f"Give an example IPv4 address with this regex: {sample_regex}"
] * 2,
sampling_params=sampling_params,
use_tqdm=True)
# Once regex is supported --
#assert outputs is not None
#for output in outputs:
# assert output is not None
# assert isinstance(output, RequestOutput)
# prompt = output.prompt
# generated_text = output.outputs[0].text
# print(generated_text)
# assert generated_text is not None
# assert re.fullmatch(sample_regex, generated_text) is not None
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
GUIDED_DECODING_BACKENDS_V1)
def test_guided_choice_completion(monkeypatch, sample_guided_choice,
guided_decoding_backend: str):
monkeypatch.setenv("VLLM_USE_V1", "1")
llm = LLM(model=MODEL_NAME, max_model_len=1024)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
guided_decoding=GuidedDecodingParams(
choice=sample_guided_choice,
backend=guided_decoding_backend))
outputs = llm.generate(
prompts="The best language for type-safe systems programming is ",
sampling_params=sampling_params,
use_tqdm=True)
assert outputs is not None
for output in outputs:
assert output is not None
assert isinstance(output, RequestOutput)
prompt = output.prompt
generated_text = output.outputs[0].text
print(generated_text)
assert generated_text is not None
assert generated_text in sample_guided_choice
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@ -0,0 +1,196 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.v1.structured_output.utils import (
has_xgrammar_unsupported_json_features)
@pytest.fixture
def unsupported_string_schemas():
return [
{
"type": "string",
"pattern": "^[a-zA-Z]+$"
},
{
"type": "string",
"enum": ["active", "inactive", "pending"]
},
{
"type": "string",
"minLength": 1
},
{
"type": "string",
"maxLength": 100
},
{
"type": "string",
"format": "email"
},
]
@pytest.fixture
def unsupported_integer_schemas():
return [
{
"type": "integer",
"minimum": 0
},
{
"type": "integer",
"maximum": 120
},
{
"type": "integer",
"exclusiveMinimum": 120
},
{
"type": "integer",
"exclusiveMaximum": 120
},
{
"type": "integer",
"multipleOf": 120
},
]
@pytest.fixture
def unsupported_number_schemas():
return [
{
"type": "number",
"minimum": 0
},
{
"type": "number",
"maximum": 120
},
{
"type": "number",
"exclusiveMinimum": 120
},
{
"type": "number",
"exclusiveMaximum": 120
},
{
"type": "number",
"multipleOf": 120
},
]
@pytest.fixture
def unsupported_array_schemas():
return [
{
"type": "array",
"uniqueItems": True
},
{
"type": "array",
"contains": {
"type": "string"
}
},
{
"type": "array",
"minContains": 1
},
{
"type": "array",
"maxContains": 5
},
{
"type": "array",
"minItems": 1
},
{
"type": "array",
"maxItems": 10
},
]
@pytest.fixture
def unsupported_object_schemas():
return [
{
"type": "object",
"minProperties": 1
},
{
"type": "object",
"maxProperties": 5
},
{
"type": "object",
"propertyNames": {
"pattern": "^[a-z]+$"
}
},
{
"type": "object",
"patternProperties": {
"^S": {
"type": "string"
}
}
},
]
@pytest.fixture
def supported_schema():
return {
"type": "object",
"properties": {
"name": {
"type": "string"
},
"age": {
"type": "integer"
},
"status": {
"type": "string"
},
"scores": {
"type": "array",
"items": {
"type": "number"
}
},
"address": {
"type": "object",
"properties": {
"street": {
"type": "string"
},
"city": {
"type": "string"
}
}
}
}
}
@pytest.mark.parametrize("schema_type", [
"unsupported_string_schemas", "unsupported_integer_schemas",
"unsupported_number_schemas", "unsupported_array_schemas",
"unsupported_object_schemas"
])
def test_unsupported_json_features_by_type(schema_type, request):
schemas = request.getfixturevalue(schema_type)
for schema in schemas:
assert has_xgrammar_unsupported_json_features(
schema), f"Schema should be unsupported: {schema}"
def test_supported_json_features(supported_schema):
assert not has_xgrammar_unsupported_json_features(
supported_schema), "Schema should be supported"
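The helper under test is not shown in this hunk. Purely for intuition, here is a rough sketch of what has_xgrammar_unsupported_json_features could look like, with the unsupported-keyword lists inferred from the fixtures above; the real implementation in vllm/v1/structured_output/utils.py may differ.

# Rough sketch; keyword lists are taken from the fixtures above.
from typing import Any

_UNSUPPORTED_KEYS = {
    "string": {"pattern", "enum", "minLength", "maxLength", "format"},
    "integer": {"minimum", "maximum", "exclusiveMinimum",
                "exclusiveMaximum", "multipleOf"},
    "number": {"minimum", "maximum", "exclusiveMinimum",
               "exclusiveMaximum", "multipleOf"},
    "array": {"uniqueItems", "contains", "minContains",
              "maxContains", "minItems", "maxItems"},
    "object": {"minProperties", "maxProperties",
               "propertyNames", "patternProperties"},
}


def has_unsupported_json_features_sketch(schema: dict[str, Any]) -> bool:
    """Return True if any (sub-)schema uses a feature xgrammar cannot handle."""

    def check(node: Any) -> bool:
        if isinstance(node, dict):
            unsupported = _UNSUPPORTED_KEYS.get(node.get("type", ""), set())
            if unsupported & node.keys():
                return True
            # Recurse into nested schemas (properties, items, etc.).
            return any(check(value) for value in node.values())
        if isinstance(node, list):
            return any(check(value) for value in node)
        return False

    return check(schema)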


@ -72,6 +72,8 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
@ -135,6 +137,8 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=0,
finished_req_ids={req_id},
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -165,6 +169,8 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -190,6 +196,8 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -221,6 +229,8 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -256,6 +266,8 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None,
)
metadata_before = model_runner._update_states(scheduler_output)


@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import argparse
import asyncio
import concurrent
@ -8,6 +10,7 @@ import datetime
import enum
import gc
import getpass
import importlib
import importlib.metadata
import importlib.util
import inspect
@ -23,6 +26,7 @@ import tempfile
import threading
import time
import traceback
import types
import uuid
import warnings
import weakref
@ -982,7 +986,7 @@ def current_stream() -> torch.cuda.Stream:
return _current_stream
def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""
@ -1977,7 +1981,7 @@ class MemorySnapshot:
self.non_torch_memory = self.cuda_memory - self.torch_memory
self.timestamp = time.time()
def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
return MemorySnapshot(
torch_peak=self.torch_peak - other.torch_peak,
cuda_memory=self.cuda_memory - other.cuda_memory,
@ -2306,3 +2310,54 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:
type.__setattr__(cls, '__init__', wrapped_init)
return cls
class LazyLoader(types.ModuleType):
"""
LazyLoader module borrowed from TensorFlow
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
with the addition of module caching.
Lazily import a module, mainly to avoid pulling in large dependencies.
Modules such as `xgrammar` might have import-time side effects, so we
only want to import them when needed, delaying all eager effects.
"""
def __init__(
self,
local_name: str,
parent_module_globals: dict[str, Any],
name: str,
):
self._local_name = local_name
self._parent_module_globals = parent_module_globals
self._module: types.ModuleType | None = None
super().__init__(str(name))
def _load(self) -> types.ModuleType:
# Import the target module and insert it into the parent's namespace
try:
module = importlib.import_module(self.__name__)
self._parent_module_globals[self._local_name] = module
# Registering the module in sys.modules under the local name
# ensures the library is actually loaded.
sys.modules[self._local_name] = module
except ModuleNotFoundError as err:
raise err from None
# Update this object's dict so that if someone keeps a
# reference to the LazyLoader, lookups are efficient
# (__getattr__ is only called on lookups that fail).
self.__dict__.update(module.__dict__)
return module
def __getattr__(self, item: Any) -> Any:
if self._module is None:
self._module = self._load()
return getattr(self._module, item)
def __dir__(self) -> list[str]:
if self._module is None:
self._module = self._load()
return dir(self._module)
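A brief usage sketch of LazyLoader follows. The alias uses numpy only to keep the example self-contained; in this commit the class is presumably intended to defer heavyweight imports such as xgrammar.

# Nothing is imported at this point; only the alias is created.
np = LazyLoader("np", globals(), "numpy")


def make_zeros(n: int):
    # The first attribute access triggers the real `import numpy`,
    # caches the module, and forwards the lookup.
    return np.zeros(n)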


@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import time
from collections import deque
from collections.abc import Iterable
@ -18,6 +20,7 @@ from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
logger = init_logger(__name__)
@ -32,12 +35,14 @@ class Scheduler:
lora_config: Optional[LoRAConfig],
speculative_config: Optional[SpeculativeConfig],
log_stats: bool,
structured_output_manager: StructuredOutputManager,
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.lora_config = lora_config
self.speculative_config = speculative_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@ -97,7 +102,7 @@ class Scheduler:
self.encoder_cache_manager = EncoderCacheManager(
cache_size=encoder_cache_size)
def schedule(self) -> "SchedulerOutput":
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
# There's no "decoding phase" nor "prefill phase" in the scheduler.
# Each request just has the num_computed_tokens and
@ -114,6 +119,14 @@ class Scheduler:
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
# NOTE: structured_output_request_ids maps the request_id of
# each request that uses structured output to its index in the
# running batch. This lets us slice the grammar bitmask and
# apply a valid mask only to requests that use structured
# decoding.
structured_output_request_ids: dict[str, int] = {}
req_to_new_block_ids: dict[str, list[int]] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
@ -184,6 +197,12 @@ class Scheduler:
# Schedule the request.
scheduled_running_reqs.append(request)
self.scheduled_req_ids.add(request.request_id)
if request.use_structured_output:
# PERF: with chunked prefill, the request might not
# produce any new tokens this step, so we may spend
# extra cycles filling in the bitmask, which can
# amount to a large no-op.
structured_output_request_ids[request.request_id] = req_index
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
@ -219,6 +238,10 @@ class Scheduler:
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(requested_loras) <= self.lora_config.max_loras
# Use a temporary deque to collect requests that need to be skipped
# and put back at the head of the waiting queue later
waiting_for_fsm: deque[Request] = deque()
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0:
@ -227,6 +250,16 @@ class Scheduler:
request = self.waiting[0]
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
waiting_structured_output_req = self.waiting.popleft()
waiting_for_fsm.appendleft(
waiting_structured_output_req)
continue
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request:
@ -281,6 +314,10 @@ class Scheduler:
break
self.waiting.popleft()
if request.use_structured_output:
structured_output_request_ids[
request.request_id] = req_index
req_index += 1
self.running.append(request)
self.scheduled_req_ids.add(request.request_id)
self.request_scheduled(request, scheduled_timestamp)
@ -311,6 +348,10 @@ class Scheduler:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget
# Put back any skipped requests at the head of the waiting queue
if waiting_for_fsm:
self.waiting.extendleft(waiting_for_fsm)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
@ -331,6 +372,11 @@ class Scheduler:
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
len(self.running),
)
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(req,
@ -369,6 +415,8 @@ class Scheduler:
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
self.finished_req_ids = set()
@ -381,7 +429,7 @@ class Scheduler:
num_scheduled_spec_tokens: int,
new_block_ids: list[int],
resumed_from_preemption: bool,
) -> "CachedRequestData":
) -> CachedRequestData:
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
# them at each scheduling step.
num_computed_tokens = request.num_computed_tokens
@ -474,8 +522,8 @@ class Scheduler:
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
scheduler_output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
) -> EngineCoreOutputs:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
@ -565,6 +613,15 @@ class Scheduler:
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and request.use_structured_output:
# NOTE: structured_output_request cannot be None when
# use_structured_output is True (checked above), so it is
# safe to ignore the type warning here.
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
request.request_id,
new_token_ids,
)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if new_token_ids or prompt_logprobs_tensors is not None:
# Add EngineCoreOutput for this Request.


@ -1,9 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.base import PlaceholderRange
@ -17,20 +22,20 @@ class NewRequestData:
req_id: str
prompt_token_ids: list[int]
prompt: Optional[str]
mm_inputs: list["MultiModalKwargs"]
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list["PlaceholderRange"]
sampling_params: "SamplingParams"
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
block_ids: list[int]
num_computed_tokens: int
lora_request: Optional["LoRARequest"]
lora_request: Optional[LoRARequest]
@classmethod
def from_request(
cls,
request: "Request",
request: Request,
block_ids: list[int],
) -> "NewRequestData":
) -> NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
@ -60,11 +65,11 @@ class CachedRequestData:
@classmethod
def from_request(
cls,
request: "Request",
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: list[int],
) -> "CachedRequestData":
) -> CachedRequestData:
return cls(
req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption,
@ -111,3 +116,9 @@ class SchedulerOutput:
# list of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]]
# Dict of request ids to their index within the batch,
# used for filling the next-token bitmask.
structured_output_request_ids: dict[str, int]
# The grammar bitmask for the whole batch.
grammar_bitmask: Optional[npt.NDArray[np.int32]]
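For intuition, here is a sketch of how a consumer of SchedulerOutput might apply grammar_bitmask when sampling. The packing convention (32 tokens per int32 element, with a set bit marking an allowed token) and the one-row-per-running-request layout are assumptions for illustration, not something this hunk specifies.

# Illustrative only; see the assumptions stated above.
import numpy as np
import torch


def apply_grammar_bitmask(
    logits: torch.Tensor,                      # [num_running_reqs, vocab_size]
    structured_output_request_ids: dict[str, int],
    grammar_bitmask: np.ndarray,               # int32, one row per running request
) -> None:
    vocab_size = logits.shape[-1]
    for req_id, batch_index in structured_output_request_ids.items():
        packed = grammar_bitmask[batch_index].view(np.uint8)
        allowed = np.unpackbits(packed, bitorder="little")[:vocab_size].astype(bool)
        mask = torch.from_numpy(allowed).to(logits.device)
        # Tokens the grammar forbids get -inf so they can never be sampled.
        logits[batch_index].masked_fill_(~mask, float("-inf"))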


@ -72,9 +72,7 @@ class AsyncLLM(EngineClient):
# Processor (converts Inputs --> EngineCoreRequests).
self.processor = Processor(
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
)


@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -61,6 +62,8 @@ class EngineCore:
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
self.structured_output_manager = StructuredOutputManager(vllm_config)
# Setup scheduler.
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
@ -69,6 +72,7 @@ class EngineCore:
lora_config=vllm_config.lora_config,
speculative_config=vllm_config.speculative_config,
log_stats=self.log_stats,
structured_output_manager=self.structured_output_manager,
)
# Setup MM Input Mapper.
@ -131,6 +135,9 @@ class EngineCore:
request.mm_inputs, request.mm_hashes)
req = Request.from_engine_core_request(request)
if req.use_structured_output:
# Start grammar compilation asynchronously
self.structured_output_manager.populate_cache(req)
self.scheduler.add_request(req)
@ -148,11 +155,24 @@ class EngineCore:
if not self.scheduler.has_unfinished_requests():
return EngineCoreOutputs(
outputs=[], scheduler_stats=self.scheduler.make_stats())
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
scheduler_output = self.scheduler.schedule()
# This case may occur when the only unfinished requests are
# structured output requests where the grammar has not finished
# compiling yet, so there's nothing to run.
if scheduler_output.total_num_scheduled_tokens == 0:
return EngineCoreOutputs(
outputs=[],
scheduler_stats=self.scheduler.make_stats(),
)
output = self.model_executor.execute_model(scheduler_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, output) # type: ignore
return engine_core_outputs
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:


@ -66,9 +66,7 @@ class LLMEngine:
self.tokenizer.ping()
# Processor (convert Inputs --> EngineCoreRequests)
self.processor = Processor(model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
lora_config=vllm_config.lora_config,
self.processor = Processor(vllm_config=vllm_config,
tokenizer=self.tokenizer,
input_registry=input_registry,
mm_registry=mm_registry)


@ -4,7 +4,7 @@ import time
from collections.abc import Mapping
from typing import Optional, Union
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.structured_output.utils import validate_structured_output_request
class Processor:
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
vllm_config: VllmConfig,
tokenizer: BaseTokenizerGroup,
input_registry: InputRegistry = INPUT_REGISTRY,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
):
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.decoding_config = vllm_config.decoding_config
self.tokenizer = tokenizer
self.generation_config_fields = model_config.try_get_generation_config(
)
self.input_preprocessor = InputPreprocessor(model_config,
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_processor = input_registry.create_input_processor(
model_config)
self.model_config)
# Multi-modal (huggingface) input mapper
self.mm_input_cache_client = MMInputCacheClient(model_config)
self.mm_input_cache_client = MMInputCacheClient(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
cache_config.enable_prefix_caching
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
self.cache_config.enable_prefix_caching
def _validate_logprobs(
self,
@ -80,6 +82,8 @@ class Processor:
self,
params: SamplingParams,
) -> None:
self._validate_structured_output(params)
if params.allowed_token_ids is None:
return
if not params.allowed_token_ids:
@ -125,6 +129,21 @@ class Processor:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
"not enabled!")
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.decoding_config.guided_decoding_backend != "xgrammar":
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if (params.guided_decoding.backend
and params.guided_decoding.backend != 'xgrammar'):
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if self.vllm_config.speculative_config:
raise ValueError("Structured output is not supported with "
"speculative decoding.")
validate_structured_output_request(params)
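For reference, a hedged end-to-end sketch of the kind of request this validation admits. The model id and prompt are placeholders, and it assumes a V1 build with xgrammar installed and VLLM_USE_V1=1 set; regex guidance would be rejected by _validate_structured_output above.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Placeholder model id; any model supported by V1 works here.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct",
          guided_decoding_backend="xgrammar")

params = SamplingParams(
    max_tokens=64,
    guided_decoding=GuidedDecodingParams(json={
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }))

outputs = llm.generate(["Return a JSON object with a name field."], params)
print(outputs[0].outputs[0].text)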
def process_inputs(
self,
request_id: str,

View File

@ -3,13 +3,15 @@
import enum
from typing import TYPE_CHECKING, Optional, Union
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
EngineCoreRequest, FinishReason)
from vllm.v1.structured_output.request import StructuredOutputRequest
from vllm.v1.utils import ConstantList
if TYPE_CHECKING:
from vllm.lora.request import LoRARequest
from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
@ -27,15 +29,19 @@ class Request:
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
) -> None:
self.request_id = request_id
self.sampling_params = sampling_params
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.status = RequestStatus.WAITING
self.status = (RequestStatus.WAITING_FOR_FSM
if sampling_params.guided_decoding is not None else
RequestStatus.WAITING)
self.events: list[EngineCoreEvent] = []
self.stop_reason: Union[int, str, None] = None
assert sampling_params.max_tokens is not None
@ -78,6 +84,8 @@ class Request:
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params),
)
def queued(self, timestamp: Optional[float] = None) -> None:
@ -134,18 +142,23 @@ class Request:
num_tokens = self.mm_positions[input_id]["length"]
return num_tokens
@property
def use_structured_output(self) -> bool:
return self.sampling_params.guided_decoding is not None
class RequestStatus(enum.IntEnum):
"""Status of a request."""
WAITING = 0
RUNNING = 1
PREEMPTED = 2
# Note: anything after PREEMPTED (2) will be considered
WAITING = enum.auto()
WAITING_FOR_FSM = enum.auto()
RUNNING = enum.auto()
PREEMPTED = enum.auto()
# Note: anything after PREEMPTED will be considered
# as a finished status.
FINISHED_STOPPED = 3
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
FINISHED_STOPPED = enum.auto()
FINISHED_LENGTH_CAPPED = enum.auto()
FINISHED_ABORTED = enum.auto()
FINISHED_IGNORED = enum.auto()
@staticmethod
def is_finished(status: "RequestStatus") -> bool:
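The is_finished helper keeps its existing body (not shown in this hunk); the switch to enum.auto() preserves the ordering invariant spelled out in the comment above. A standalone re-declaration demonstrating that invariant, assuming a simple greater-than check rather than quoting the real implementation:

import enum


class Status(enum.IntEnum):
    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    # Everything declared after PREEMPTED is treated as terminal.
    FINISHED_STOPPED = enum.auto()
    FINISHED_ABORTED = enum.auto()


def is_finished(status: Status) -> bool:
    return status > Status.PREEMPTED


assert not is_finished(Status.WAITING_FOR_FSM)
assert is_finished(Status.FINISHED_ABORTED)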

View File

@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import copy
import multiprocessing
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
StructuredOutputOptions)
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import xgrammar as xgr
from vllm.v1.request import Request
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class StructuredOutputManager:
def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500):
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.vllm_config = vllm_config
tokenizer = tokenizer_group.get_lora_tokenizer(None)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer, vocab_size=self.vocab_size)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
self.max_cache_size = max_cache_size
self.request_key_to_grammar: OrderedDict[StructuredOutputKey,
Grammar] = OrderedDict()
# The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound.
# We also know we would never dominate CPU usage with just grammar
# compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
# We need to pop and re-insert the grammar here for LRU cache
# of request_key_to_grammar
if key in self.request_key_to_grammar:
# Move accessed item to the end (most recently used)
value = self.request_key_to_grammar.pop(key)
if value is not None:
self.request_key_to_grammar[key] = value
return value
return None
def populate_cache(self, request: Request) -> None:
if request.structured_output_request is None:
return
grammar = self.request_key_to_grammar.get(
request.structured_output_request.structured_output_key)
if grammar:
request.structured_output_request.grammar = copy.copy(grammar)
return
request.structured_output_request.grammar = self.cache(request)
def cache(self, request: Request):
return self.executor.submit(self._executor_loop, request)
def _executor_loop(self, request: Request) -> Grammar:
# NOTE: The structured_output_request should never be
# None in this case, but mypy can't infer this
# correctly, so we need to ignore the error here.
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
grammar = self.request_key_to_grammar.get(key)
if grammar is not None:
return copy.copy(grammar)
grammar = self.initialize_grammar(key)
# If cache is full, remove the least recently used item
if len(self.request_key_to_grammar) >= self.max_cache_size:
self.request_key_to_grammar.popitem(last=False)
self.request_key_to_grammar[key] = grammar
return copy.copy(grammar)
def initialize_grammar(self, key: StructuredOutputKey) -> Grammar:
# Note that the request was validated in the engine core client,
# so at this point we know it is a supported type of request.
#
# TODO: we still need to handle xgrammar compilation failures
request_type, grammar_spec = key
if request_type == StructuredOutputOptions.JSON:
# TODO -- allow any_whitespace to be configurable
# pending merge of https://github.com/vllm-project/vllm/pull/12744
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec)
else:
logger.error("Validation should have already occurred. "
"Please file an issue.")
raise ValueError(
f"grammar is not of valid supported types. ({request_type!s})")
return Grammar(
matcher=xgr.GrammarMatcher(ctx),
vocab_size=self.vocab_size,
ctx=ctx,
)
def grammar_bitmask(
self,
requests: dict[str, Request],
structured_output_request_ids: dict[str, int],
batch_len: int,
) -> Optional[npt.NDArray[np.int32]]:
# Prepare the structured output bitmask for this batch.
if not structured_output_request_ids:
return None
# Fill the bitmask row for each request, using the request's
# position in the batch as its index. Then resize the bitmask
# down to the size of the batch.
bitmask_tensor = self._grammar_bitmask
for req_id, batch_index in structured_output_request_ids.items():
request = requests[req_id].structured_output_request
assert request is not None and request.grammar is not None
if not request.grammar.matcher.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len]
# After finishing with the xgrammar operations, we convert to
# np.ndarray, because that is much more efficient for serialization
# and deserialization when sending this to the GPU workers.
return bitmask_tensor.numpy()
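A standalone sketch (a simplification, not the class above) of the two ideas StructuredOutputManager combines: an OrderedDict used as an LRU cache keyed by the grammar spec, and a ThreadPoolExecutor so compilation never blocks the scheduling loop.

import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor


class GrammarCache:
    def __init__(self, max_size: int = 2):
        self.max_size = max_size
        self.cache: OrderedDict[str, str] = OrderedDict()
        self.executor = ThreadPoolExecutor(max_workers=2)

    def _compile(self, key: str) -> str:
        if key in self.cache:
            self.cache.move_to_end(key)      # mark as most recently used
            return self.cache[key]
        time.sleep(0.1)                      # stands in for xgrammar compilation
        if len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)   # evict the least recently used entry
        self.cache[key] = f"grammar<{key}>"
        return self.cache[key]

    def populate(self, key: str):
        # Returns a Future; the caller keeps scheduling other work meanwhile.
        return self.executor.submit(self._compile, key)


cache = GrammarCache()
fut = cache.populate('{"type": "object"}')
print(fut.result())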

View File

@ -0,0 +1,77 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from vllm.logger import init_logger
from vllm.utils import LazyLoader
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
class StructuredOutputOptions(enum.Enum):
JSON = enum.auto()
JSON_OBJECT = enum.auto()
REGEX = enum.auto()
GRAMMAR = enum.auto()
CHOICE = enum.auto()
StructuredOutputKey = tuple[StructuredOutputOptions, str]
@dataclass
class Grammar:
# NOTE: This would be a generic-enough class for
# supporting different backends, in the future.
# For now, just xgrammar.
#
# TODO: support max_rollback_tokens
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
# for jump-forward decoding
vocab_size: int
matcher: xgr.GrammarMatcher = field(hash=False)
ctx: xgr.CompiledGrammar = field(hash=False)
num_processed_tokens: int = field(default_factory=lambda: 0,
repr=False,
hash=False,
init=False)
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""Accepts a list of tokens and advances the FSM.
Returns True if the FSM was advanced successfully.
Returns False if the FSM failed to advance.
"""
for token in tokens:
if not self.matcher.accept_token(token):
logger.error(
"Failed to advance FSM for request %s "
"for token %s. Please file an issue.", request_id, token)
return False
self.num_processed_tokens += 1
return True
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
return self.matcher.fill_next_token_bitmask(bitmask, idx)
def reset(self):
self.num_processed_tokens = 0
self.matcher.reset()
def __copy__(self):
return Grammar(
matcher=xgr.GrammarMatcher(self.ctx),
vocab_size=self.vocab_size,
ctx=self.ctx,
)
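The __copy__ above gives each request its own matcher while reusing the compiled grammar. A mocked illustration of that split, with the xgrammar objects replaced by plain stand-ins:

import copy
from dataclasses import dataclass, field


@dataclass
class MockGrammar:
    ctx: str                                  # stands in for xgr.CompiledGrammar
    tokens_seen: list[int] = field(default_factory=list)

    def __copy__(self):
        # Fresh per-request matcher state, shared compiled context.
        return MockGrammar(ctx=self.ctx)


base = MockGrammar(ctx="compiled-json-schema")
base.tokens_seen.append(42)

clone = copy.copy(base)
assert clone.ctx is base.ctx          # compiled grammar is reused
assert clone.tokens_seen == []        # matcher progress is not shared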

View File

@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import dataclasses
import functools
import json
from concurrent.futures import Future
from concurrent.futures._base import TimeoutError
from typing import Optional, Union, cast
from vllm.sampling_params import SamplingParams
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
StructuredOutputOptions)
@dataclasses.dataclass
class StructuredOutputRequest:
sampling_params: SamplingParams
_grammar: Optional[Union[Future[Grammar], Grammar]] = None
def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
from vllm.v1.request import RequestStatus
if isinstance(self._grammar, Future):
try:
# We will check whether the future is ready within 100 us
self._grammar = self._grammar.result(timeout=0.0001)
self.status = RequestStatus.WAITING
except TimeoutError:
return False
return True
@property
def is_grammar_ready(self) -> bool:
return self._check_grammar_completion()
@property
def grammar(self) -> Optional[Grammar]:
completed = self._check_grammar_completion()
return cast(Optional[Grammar], self._grammar) if completed else None
@grammar.setter
def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None:
self._grammar = grammar
@functools.cached_property
def structured_output_key(self) -> StructuredOutputKey:
params = self.sampling_params.guided_decoding
assert params is not None, "params can't be None."
if params.json is not None:
if not isinstance(params.json, str):
json_str = json.dumps(params.json)
else:
json_str = params.json
return (StructuredOutputOptions.JSON, json_str)
elif params.json_object:
return (StructuredOutputOptions.JSON_OBJECT, "")
elif params.regex is not None:
return (StructuredOutputOptions.REGEX, params.regex)
elif params.choice is not None:
if not isinstance(params.choice, str):
json_str = json.dumps(params.choice)
else:
json_str = params.choice
return (StructuredOutputOptions.CHOICE, json_str)
elif params.grammar is not None:
return (StructuredOutputOptions.GRAMMAR, params.grammar)
else:
raise ValueError("No valid structured output parameter found")

View File

@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import json
import re
from typing import TYPE_CHECKING, Any
from vllm.sampling_params import SamplingParams
from vllm.utils import LazyLoader
if TYPE_CHECKING:
import xgrammar as xgr
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
"""Check if JSON schema contains features unsupported by xgrammar."""
def check_object(obj: dict[str, Any]) -> bool:
if not isinstance(obj, dict):
return False
# Check for pattern restrictions
if "pattern" in obj:
return True
# Check for enum restrictions
if "enum" in obj:
return True
# Check for numeric ranges
if obj.get("type") in ("integer", "number") and any(
key in obj
for key in ("minimum", "maximum", "exclusiveMinimum",
"exclusiveMaximum", "multipleOf")):
return True
# Check for array unsupported keywords
if obj.get("type") == "array" and any(
key in obj
for key in ("uniqueItems", "contains", "minContains",
"maxContains", "minItems", "maxItems")):
return True
# Unsupported keywords for strings
if obj.get("type") == "string" and any(
key in obj for key in ("minLength", "maxLength", "format")):
return True
# Unsupported keywords for objects
if obj.get("type") == "object" and any(
key in obj for key in ("minProperties", "maxProperties",
"propertyNames", "patternProperties")):
return True
# Recursively check all nested objects and arrays
for value in obj.values():
if isinstance(value, dict):
if check_object(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_object(item):
return True
return False
return check_object(schema)
def grammar_is_likely_lark(grammar_str: str) -> bool:
"""
Check if grammar appears to use Lark syntax.
Args:
grammar_str: Input grammar string
Returns:
bool: True if grammar appears to be in Lark format, False otherwise
Examples:
>>> grammar_is_likely_lark("rule: 'abc'")
True
>>> grammar_is_likely_lark("rule ::= 'abc'")
False
"""
if not grammar_str or not isinstance(grammar_str, str):
return False
for line in grammar_str.split('\n'):
# Remove both comment styles
line = re.sub(r'(#|//).*$', '', line).strip()
if not line:
continue
# Look for EBNF rule definition
if '::=' in line:
return False
return True
def convert_lark_to_ebnf(grammar_str: str) -> str:
"""
Convert a Lark grammar string to EBNF format.
EBNF reference:
https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
Lark grammar reference:
https://lark-parser.readthedocs.io/en/latest/grammar.html
Args:
grammar_str: Input grammar in Lark format
Returns:
str: Converted grammar in EBNF format
Examples:
>>> print(convert_lark_to_ebnf("rule: 'hello'"))
root ::= rule
rule ::= "hello"
"""
if not isinstance(grammar_str, str):
raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
if not grammar_str.strip():
raise ValueError("Grammar string cannot be empty")
defined_rules = set()
referenced_rules = set()
output_lines = []
def clean_line(line: str) -> str:
"""Remove comments and whitespace from line."""
return re.sub(r'(#|//).*$', '', line).strip()
def check_quotes(text: str, rule_name: str, line_num: int) -> None:
"""Validate quote matching in text."""
if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
raise ValueError(
f"Mismatched quotes in {rule_name} on line {line_num}")
def extract_references(text: str) -> set:
"""Extract rule references from text."""
# Remove quoted strings and special characters
text = re.sub(r'"[^"]*"', '', text)
text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))
# First pass: Find root rule and validate rule definitions
lines = [clean_line(line) for line in grammar_str.split('\n')]
first_rule = None
for line_num, line in enumerate(lines, 1):
if not line or line.startswith('|'):
continue
if ':' in line:
try:
name = line.split(':', 1)[0].strip().strip('?')
defined_rules.add(name)
if first_rule is None:
first_rule = name
if name == 'start':
first_rule = 'start'
except IndexError as e:
raise ValueError(f"Invalid rule format on line {line_num}. "
"Expected 'rule_name: definition'") from e
if not defined_rules:
raise ValueError("No valid rules found in grammar")
# Add root rule
output_lines.append(f"root ::= {first_rule}")
# Second pass: Process rule definitions and alternatives
current_rule = None
current_definition = []
for line_num, line in enumerate(lines, 1):
if not line:
continue
try:
if ':' in line and not line.startswith('|'):
# Save previous rule if exists
if current_rule:
output_lines.append(
f"{current_rule} ::= {' | '.join(current_definition)}")
# Process new rule
name, definition = line.split(':', 1)
current_rule = name.strip().strip('?')
check_quotes(definition, f"rule '{current_rule}'", line_num)
definition = re.sub(r"'([^']*)'", r'"\1"', definition)
referenced_rules.update(extract_references(definition))
current_definition = [definition.strip()]
elif line.startswith('|'):
if not current_rule:
raise ValueError(f"Alternative '|' on line {line_num} "
"without a preceding rule definition")
alt_def = line[1:].strip()
check_quotes(alt_def, f"alternative for rule '{current_rule}'",
line_num)
alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
referenced_rules.update(extract_references(alt_def))
current_definition.append(alt_def)
except ValueError as e:
raise ValueError(f"Error on line {line_num}: {str(e)}") from e
# Add final rule if exists
if current_rule:
output_lines.append(
f"{current_rule} ::= {' | '.join(current_definition)}")
# Validate all rules are defined
undefined_rules = referenced_rules - defined_rules - {'root'}
if undefined_rules:
raise ValueError("Referenced rules are not defined: "
f"{', '.join(sorted(undefined_rules))}")
return '\n'.join(output_lines)
def choice_as_grammar(choice: list[str]) -> str:
def escape_ebnf_string(s: str) -> str:
"""Escape special characters in a EBNF string."""
# Escape double quotes and backslashes
return re.sub(r'(["\\])', r'\\\1', s)
escaped_choices = (escape_ebnf_string(c) for c in choice)
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
return grammar
def validate_structured_output_request(
sampling_params: SamplingParams) -> None:
"""Validate that the request is supported by structured output.
Raises ValueError if the request is not supported.
"""
if sampling_params.guided_decoding is None:
return
gd_params = sampling_params.guided_decoding
if gd_params.regex:
raise ValueError("Regex structured output is not supported.")
if gd_params.choice:
choice_grammar = choice_as_grammar(gd_params.choice)
try:
xgr.Grammar.from_ebnf(choice_grammar)
except Exception as err:
raise ValueError("Failed to transform choices into a grammar: "
"{err}") from err
gd_params.choice = None
gd_params.grammar = choice_grammar
return
if gd_params.json:
if isinstance(gd_params.json, str):
try:
schema = json.loads(gd_params.json)
except json.JSONDecodeError as e:
raise ValueError("Invalid JSON grammar specification.") from e
else:
schema = gd_params.json
if has_xgrammar_unsupported_json_features(schema):
raise ValueError("The provided JSON schema contains features not "
"supported by xgrammar.")
return
if gd_params.grammar:
if grammar_is_likely_lark(gd_params.grammar):
# xgrammar supports EBNF grammars only
try:
gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
except ValueError as e:
raise ValueError(
"Failed to convert the grammar from Lark to EBNF. ") from e
# Test parsing EBNF grammar, possibly already converted from Lark
try:
# parse the grammar, but we aren't compiling it.
xgr.Grammar.from_ebnf(gd_params.grammar)
except Exception as e:
raise ValueError("Invalid grammar specification.") from e

View File

@ -25,7 +25,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available)
LayerBlockType, LazyLoader, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
@ -40,7 +41,11 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
if TYPE_CHECKING:
import xgrammar as xgr
from vllm.v1.core.scheduler_output import SchedulerOutput
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
logger = init_logger(__name__)
@ -860,6 +865,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def get_model(self) -> nn.Module:
return self.model
def apply_grammar_bitmask(
self,
scheduler_output: "SchedulerOutput",
logits: torch.Tensor,
):
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return
# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the gpu runner is
# ordering the requests in the batch. We need to sort the bitmask to
# match the order of the requests used here.
struct_out_req_batch_indices: dict[str, int] = {}
indices_match = True
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
# not a structured output request
continue
batch_index = self.input_batch.req_id_to_index[req_id]
if batch_index != mask_index:
indices_match = False
struct_out_req_batch_indices[req_id] = batch_index
if not indices_match:
# Sort the bitmask to match the order of the requests
sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, batch_index in struct_out_req_batch_indices.items():
orig_index = scheduler_output.structured_output_request_ids[
req_id]
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
grammar_bitmask = sorted_bitmask
grammar_bitmask = torch.from_numpy(grammar_bitmask)
# TODO: compatibility with spec decode
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(self.device, non_blocking=True),
indices=list(struct_out_req_batch_indices.values()),
)
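A simplified sketch of what apply_grammar_bitmask does, using a boolean mask instead of xgrammar's packed int32 bitmask: reorder the mask rows into the runner's batch order, then force disallowed tokens to -inf before sampling. Request ids and sizes here are made up.

import torch

vocab = 8
logits = torch.randn(3, vocab)                 # batch of 3 in the runner's order
scheduler_rows = {"req-a": 0, "req-b": 1}      # bitmask row per structured request
batch_index = {"req-a": 2, "req-b": 0}         # runner's batch index per request

allowed = torch.zeros(2, vocab, dtype=torch.bool)
allowed[0, :4] = True                          # tokens the grammar allows for req-a
allowed[1, 4:] = True                          # tokens the grammar allows for req-b

# Reorder rows to match the runner's batch layout; rows without a structured
# output request stay fully allowed.
sorted_allowed = torch.ones(3, vocab, dtype=torch.bool)
for req_id, row in scheduler_rows.items():
    sorted_allowed[batch_index[req_id]] = allowed[row]

# In spirit, what xgr.apply_token_bitmask_inplace does on the real path.
logits.masked_fill_(~sorted_allowed, float("-inf"))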
@torch.inference_mode()
def execute_model(
self,
@ -945,6 +997,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
self.apply_grammar_bitmask(scheduler_output, logits)
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
if not self.use_spec_decode: