[V1][Core] Support for Structured Outputs (#12388)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>

parent 1e3598edeb
commit 80e9afb5bc
@@ -204,6 +204,7 @@ steps:
    - VLLM_USE_V1=1 pytest -v -s v1/engine
    - VLLM_USE_V1=1 pytest -v -s v1/sample
    - VLLM_USE_V1=1 pytest -v -s v1/worker
    - VLLM_USE_V1=1 pytest -v -s v1/structured_output
    - VLLM_USE_V1=1 pytest -v -s v1/test_stats.py
    - VLLM_USE_V1=1 pytest -v -s v1/test_utils.py
    # TODO: accuracy does not match, whether setting
2  .gitignore    vendored
@@ -197,7 +197,7 @@ _build/
hip_compat.h

# Benchmark dataset
benchmarks/*.json
benchmarks/**/*.json

# Linting
actionlint
@ -1,507 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Benchmark guided decoding throughput."""
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
import uvloop
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
|
||||
from vllm.entrypoints.openai.api_server import (
|
||||
build_async_engine_client_from_engine_args)
|
||||
from vllm.sampling_params import GuidedDecodingParams
|
||||
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SampleRequest:
|
||||
"""A class representing a single inference request for benchmarking.
|
||||
|
||||
Attributes:
|
||||
prompt: The input text prompt for the model.
|
||||
multi_modal_data: Optional dictionary containing multi-modal data (e.g.
|
||||
images).
|
||||
prompt_len: The length of the prompt in tokens.
|
||||
expected_output_len: The expected length of the output in tokens.
|
||||
"""
|
||||
prompt: str
|
||||
prompt_len: int
|
||||
expected_output_len: int
|
||||
schema: dict
|
||||
structure_type: str = 'json'
|
||||
completion: str = None
|
||||
|
||||
|
||||
def run_vllm(requests: list[SampleRequest],
|
||||
engine_args: EngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
warmup: bool = False) -> float:
|
||||
from vllm import LLM, SamplingParams
|
||||
llm = LLM(**vars(engine_args))
|
||||
assert all(
|
||||
llm.llm_engine.model_config.max_model_len >= (
|
||||
request.prompt_len + request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
# create a list containing random selected true or false
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
|
||||
if warmup:
|
||||
print(">>>>> Running warmup prompt, for the first 5")
|
||||
# We setup the first 5 requests to warmup FSM
|
||||
# if using xgrammar dataset, we will skip warmup
|
||||
warmup_requests = requests[:5]
|
||||
for i, request in enumerate(warmup_requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(json=request.schema)
|
||||
if guided_decoding_rate > 0 else None,
|
||||
))
|
||||
llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
|
||||
print(">>>>> Benchmark started...")
|
||||
prompts = []
|
||||
sampling_params = []
|
||||
for i, request in enumerate(requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
**{request.structure_type: request.schema})
|
||||
if i in guided_decoding_req_idx else None,
|
||||
))
|
||||
|
||||
start = time.perf_counter()
|
||||
outputs = llm.generate(prompts, sampling_params, use_tqdm=False)
|
||||
ret = []
|
||||
for output, request in zip(outputs, requests):
|
||||
generated_text = output.outputs[0].text
|
||||
ret.append({
|
||||
"generated": generated_text,
|
||||
"expected": request.completion
|
||||
})
|
||||
end = time.perf_counter()
|
||||
return end - start, ret
|
||||
|
||||
|
||||
async def run_vllm_async(
|
||||
requests: list[SampleRequest],
|
||||
engine_args: AsyncEngineArgs,
|
||||
n: int,
|
||||
guided_decoding_rate: float = 1.0,
|
||||
warmup: bool = False,
|
||||
disable_frontend_multiprocessing: bool = False) -> float:
|
||||
from vllm import SamplingParams
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, disable_frontend_multiprocessing) as llm:
|
||||
|
||||
assert all(
|
||||
llm.model_config.max_model_len >= (request.prompt_len +
|
||||
request.expected_output_len)
|
||||
for request in requests), (
|
||||
"Please ensure that max_model_len is greater than the sum of"
|
||||
" prompt_len and expected_output_len for all requests.")
|
||||
|
||||
# Add the requests to the engine.
|
||||
prompts: list[str] = []
|
||||
sampling_params: list[SamplingParams] = []
|
||||
guided_decoding_req_idx = random.sample(
|
||||
range(len(requests)), int(len(requests) * guided_decoding_rate))
|
||||
|
||||
if warmup:
|
||||
print(">>>>>> Running warmup prompt, for the first 5")
|
||||
# We setup the first 5 requests to warmup FSM
|
||||
# if using xgrammar dataset, we will skip warmup
|
||||
warmup_requests = requests[:5]
|
||||
for i, request in enumerate(warmup_requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=request.schema)
|
||||
if guided_decoding_rate > 0 else None,
|
||||
))
|
||||
generators = []
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
async for i, res in all_gens:
|
||||
pass
|
||||
|
||||
print(">>>>> Benchmark started...")
|
||||
prompts = []
|
||||
sampling_params = []
|
||||
for i, request in enumerate(requests):
|
||||
prompts.append(request.prompt)
|
||||
sampling_params.append(
|
||||
SamplingParams(
|
||||
n=n,
|
||||
temperature=1.0,
|
||||
top_p=1.0,
|
||||
ignore_eos=True,
|
||||
max_tokens=request.expected_output_len,
|
||||
guided_decoding=GuidedDecodingParams(json=request.schema)
|
||||
if i in guided_decoding_req_idx else None,
|
||||
))
|
||||
|
||||
generators = []
|
||||
start_time = []
|
||||
latencies = []
|
||||
start = time.perf_counter()
|
||||
for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)):
|
||||
generator = llm.generate(prompt, sp, request_id=f"test{i}")
|
||||
generators.append(generator)
|
||||
start_time.append(time.perf_counter())
|
||||
latencies.append([])
|
||||
all_gens = merge_async_iterators(*generators)
|
||||
generated_texts = [''] * len(requests)
|
||||
async for i, res in all_gens:
|
||||
generated_texts[i] = res.outputs[0].text
|
||||
lat = time.perf_counter() - start_time[i]
|
||||
latencies[i].append(lat)
|
||||
ret = [{
|
||||
'generated': gt,
|
||||
'expected': req.completion
|
||||
} for gt, req in zip(generated_texts, requests)]
|
||||
end = time.perf_counter()
|
||||
first_latency = pd.Series([lat[0] * 1000 for lat in latencies])
|
||||
next_latency = pd.Series([(lat[-1] - lat[0]) / len(lat[1:]) * 1000
|
||||
for lat in latencies])
|
||||
return end - start, ret, (first_latency, next_latency)
|
||||
|
||||
|
||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||
args: argparse.Namespace) -> list[SampleRequest]:
|
||||
if args.dataset == 'json':
|
||||
if args.json_schema_path is None:
|
||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||
args.json_schema_path = os.path.join(dir_path,
|
||||
"structured_schemas",
|
||||
"structured_schema_1.json")
|
||||
with open(args.json_schema_path) as f:
|
||||
schema = json.load(f)
|
||||
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "grammar":
|
||||
schema = """
|
||||
?start: select_statement
|
||||
|
||||
?select_statement: "SELECT " column_list " FROM " table_name
|
||||
|
||||
?column_list: column_name ("," column_name)*
|
||||
|
||||
?table_name: identifier
|
||||
|
||||
?column_name: identifier
|
||||
|
||||
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
|
||||
"""
|
||||
prompt = "Generate an SQL query to show the 'username' \
|
||||
and 'email' from the 'users' table."
|
||||
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "regex":
|
||||
regex = r"\w+@\w+\.com\n"
|
||||
args.regex = regex
|
||||
prompt = "Generate an email address for Alan Turing, \
|
||||
who works in Enigma. End in .com and new line. \
|
||||
Example result: alan.turing@enigma.com\n"
|
||||
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=regex,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "choice":
|
||||
choice = ["Positive", "Negative"]
|
||||
args.choice = choice
|
||||
prompt = "Classify this sentiment: vLLM is wonderful!"
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
print(f"Input length of the prompt: {input_len} tokens")
|
||||
requests = [
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=choice,
|
||||
structure_type=args.structure_type)
|
||||
for _ in range(args.num_prompts)
|
||||
]
|
||||
|
||||
elif args.dataset == "xgrammar_bench":
|
||||
args.warmup = False
|
||||
requests: list[SampleRequest] = []
|
||||
dataset = datasets.load_dataset("NousResearch/json-mode-eval",
|
||||
split="train")
|
||||
print(f"dataset has {len(dataset)} entries")
|
||||
len_dataset = len(dataset)
|
||||
for data_point_idx in range(args.num_prompts):
|
||||
idx = data_point_idx
|
||||
while idx >= len_dataset:
|
||||
idx -= len_dataset
|
||||
schema = dataset["schema"][idx]
|
||||
prompt = tokenizer.apply_chat_template(dataset["prompt"][idx],
|
||||
tokenize=False)
|
||||
input_len = len(tokenizer(prompt).input_ids)
|
||||
completion = dataset["completion"][idx]
|
||||
|
||||
requests.append(
|
||||
SampleRequest(prompt=prompt,
|
||||
prompt_len=input_len,
|
||||
expected_output_len=args.output_len,
|
||||
schema=schema,
|
||||
completion=completion))
|
||||
|
||||
return requests
|
||||
|
||||
|
||||
def evaluate(ret, args):
|
||||
|
||||
def _eval_correctness_json(expected, actual):
|
||||
# extract json string from string using regex
|
||||
import re
|
||||
actual = actual.replace('\n', '').replace(' ', '').strip()
|
||||
try:
|
||||
actual = re.search(r'\{.*\}', actual).group()
|
||||
actual = json.loads(actual)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _eval_correctness_choice(expected, actual):
|
||||
return actual in args.choice
|
||||
|
||||
def _eval_correctness_regex(expected, actual):
|
||||
import re
|
||||
return re.match(args.regex, actual) is not None
|
||||
|
||||
def _eval_correctness(expected, actual):
|
||||
if args.structure_type == 'json':
|
||||
return _eval_correctness_json(expected, actual)
|
||||
elif args.structure_type == 'regex':
|
||||
return _eval_correctness_regex(expected, actual)
|
||||
elif args.structure_type == 'choice':
|
||||
return _eval_correctness_choice(expected, actual)
|
||||
else:
|
||||
return None
|
||||
|
||||
scores = []
|
||||
for res in ret:
|
||||
score = _eval_correctness(res['expected'], res['generated'])
|
||||
res['correctness'] = score
|
||||
scores.append(score)
|
||||
|
||||
not_none_scores = [score for score in scores if score is not None]
|
||||
|
||||
return (sum(not_none_scores) / len(not_none_scores) *
|
||||
100) if len(not_none_scores) > 0 else None
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
|
||||
# async engine is working for 'regex', 'choice' and 'grammar'
|
||||
if args.dataset == 'grammar':
|
||||
args.structure_type = 'grammar'
|
||||
args.async_engine = False
|
||||
elif args.dataset == 'regex':
|
||||
args.structure_type = 'regex'
|
||||
args.async_engine = False
|
||||
elif args.dataset == 'choice':
|
||||
args.structure_type = 'choice'
|
||||
args.async_engine = False
|
||||
else:
|
||||
args.structure_type = 'json'
|
||||
|
||||
if args.no_guided_decoding:
|
||||
args.guided_decoding_ratio = 0
|
||||
if args.save_results:
|
||||
result_file_name = f'{args.guided_decoding_ratio}guided'
|
||||
result_file_name += f"_{args.model.split('/')[-1]}"
|
||||
result_file_name += f"_{args.dataset}"
|
||||
result_file_name += f"_{args.num_prompts}"
|
||||
result_file_name += f"_out{args.output_len}"
|
||||
result_file_name += f"_async{args.async_engine}"
|
||||
result_file_name += f"_warmup{args.warmup}"
|
||||
result_file_name += f"_chunkedprefill{args.enable_chunked_prefill}"
|
||||
result_file_name += ".txt"
|
||||
else:
|
||||
result_file_name = None
|
||||
|
||||
# Synthesize a prompt with the given input length.
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
requests = sample_requests(tokenizer, args)
|
||||
|
||||
if args.async_engine:
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
elapsed_time, ret, (first_latency, next_latency) = uvloop.run(
|
||||
run_vllm_async(requests, engine_args, args.n,
|
||||
args.guided_decoding_ratio, args.warmup,
|
||||
args.disable_frontend_multiprocessing))
|
||||
else:
|
||||
engine_args = EngineArgs.from_cli_args(args)
|
||||
elapsed_time, ret = run_vllm(requests, engine_args, args.n,
|
||||
args.guided_decoding_ratio, args.warmup)
|
||||
first_latency, next_latency = None, None
|
||||
|
||||
score = evaluate(ret, args)
|
||||
total_num_tokens = sum(request.prompt_len + request.expected_output_len
|
||||
for request in requests)
|
||||
total_output_tokens = sum(request.expected_output_len
|
||||
for request in requests)
|
||||
if first_latency is not None:
|
||||
latency_breakdown = "\nFirst token latency(msecs):\n"
|
||||
latency_breakdown += f"{first_latency.describe()}"
|
||||
latency_breakdown += "\nNext token latency(msecs):\n"
|
||||
latency_breakdown += f"{next_latency.describe()}"
|
||||
print(
|
||||
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} total tokens/s, "
|
||||
f"{total_output_tokens / elapsed_time:.2f} output tokens/s",
|
||||
f"Correct rate is {score} %",
|
||||
f"{latency_breakdown if first_latency is not None else ''}")
|
||||
|
||||
# Output JSON results if specified
|
||||
if args.output_json or result_file_name:
|
||||
results = {
|
||||
"elapsed_time": elapsed_time,
|
||||
"num_requests": len(requests),
|
||||
"total_num_tokens": total_num_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"requests_per_second": len(requests) / elapsed_time,
|
||||
"tokens_per_second": f"{total_num_tokens / elapsed_time:.2f}",
|
||||
"output_tokens_per_second":
|
||||
f"{total_output_tokens / elapsed_time:.2f}",
|
||||
"correct_rate(%)": score
|
||||
}
|
||||
results = {"outputs": ret, **results}
|
||||
if first_latency is not None:
|
||||
results["first_token_latency(msecs)"] = first_latency.describe(
|
||||
).to_dict()
|
||||
results["next_token_latency(msecs)"] = next_latency.describe(
|
||||
).to_dict()
|
||||
if args.output_json:
|
||||
with open(args.output_json, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
elif result_file_name:
|
||||
with open(result_file_name, "w") as f:
|
||||
json.dump(results, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = FlexibleArgumentParser(description="Benchmark guided decoding.")
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument("--output-len",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Output length for each request. Overrides the "
|
||||
"output length from the dataset.")
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
default='json',
|
||||
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
|
||||
parser.add_argument("--json_schema_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to json schema.")
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument(
|
||||
'--output-json',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Path to save the throughput results in JSON format.')
|
||||
parser.add_argument("--async-engine",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Use vLLM async engine rather than LLM class.")
|
||||
parser.add_argument("--no-guided-decoding",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Whether to disable JSON decoding or not.")
|
||||
parser.add_argument("--guided-decoding-ratio",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Ratio of Guided Decoding requests")
|
||||
parser.add_argument("--disable-frontend-multiprocessing",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Disable decoupled async engine frontend.")
|
||||
parser.add_argument("--warmup",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Run warmup prompts before benchmark.")
|
||||
parser.add_argument("--save-results",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="save output results.")
|
||||
args = parser.parse_args()
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
main(args)
|
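The file removed above was the offline guided-decoding throughput benchmark. As a minimal sketch of the pattern it timed (the model name, schema, and prompt below are illustrative placeholders, not values from the benchmark), the core measurement reduces to:

# Sketch only: offline guided decoding with GuidedDecodingParams, timed end to end.
import time

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

schema = {"type": "object", "properties": {"name": {"type": "string"}}}
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct", max_model_len=1024)  # illustrative model

sampling = SamplingParams(
    temperature=1.0,
    max_tokens=256,
    guided_decoding=GuidedDecodingParams(json=schema),
)

start = time.perf_counter()
outputs = llm.generate(["Give me a JSON object matching the schema."] * 8, sampling)
elapsed = time.perf_counter() - start
print(f"{len(outputs) / elapsed:.2f} requests/s")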
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
r"""Benchmark online serving throughput with guided decoding.
r"""Benchmark online serving throughput with structured outputs.

On the server side, run one of the following commands:
(vLLM OpenAI API server)
@@ -9,12 +9,12 @@ On the server side, run one of the following commands:
    ./launch_tgi_server.sh <your_model> <max_batch_total_tokens>

On the client side, run:
    python benchmarks/benchmark_serving_guided.py \
    python benchmarks/benchmark_serving_structured_output.py \
        --backend <backend> \
        --model <your_model> \
        --dataset json \
        --guided-decoding-ratio 1.0 \
        --guided-decoding-backend xgrammar \
        --structured-output-ratio 1.0 \
        --structured-output-backend xgrammar \
        --request-rate 10 \
        --num-prompts 1000

@@ -52,6 +52,9 @@ try:
except ImportError:
    from argparse import ArgumentParser as FlexibleArgumentParser

from vllm.v1.structured_output.utils import (
    has_xgrammar_unsupported_json_features)

MILLISECONDS_TO_SECONDS_CONVERSION = 1000


@@ -191,7 +194,17 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase,
        requests: list[SampleRequest] = []
        dataset = datasets.load_dataset("NousResearch/json-mode-eval",
                                        split="train")
        print(f"dataset has {len(dataset)} entries")
        full_dataset_len = len(dataset)

        def _filter_func(item):
            import json
            schema = json.loads(item["schema"])
            return not has_xgrammar_unsupported_json_features(schema)

        dataset = dataset.filter(_filter_func)
        num_filtered_out = full_dataset_len - len(dataset)
        print(f"dataset has {len(dataset)} entries after filtering "
              f"out {num_filtered_out} entries with unsupported features")
        len_dataset = len(dataset)
        for data_point_idx in range(args.num_prompts):
            idx = data_point_idx
@@ -378,8 +391,8 @@ async def benchmark(
    selected_percentiles: list[str],
    ignore_eos: bool,
    max_concurrency: Optional[int],
    guided_decoding_ratio: float,
    guided_decoding_backend: str,
    structured_output_ratio: float,
    structured_output_backend: str,
    goodput_config_dict: Optional[dict[str, float]] = None,
):
    if backend in ASYNC_REQUEST_FUNCS:
@@ -391,16 +404,18 @@ async def benchmark(
        extra_body = {}
        # Add the schema to the extra_body
        extra_body[request.structure_type] = request.schema
        # Add the specific guided_decoding_backend
        extra_body["guided_decoding_backend"] = guided_decoding_backend
        # Add the specific structured_output_backend
        extra_body["guided_decoding_backend"] = structured_output_backend
        return extra_body

    print("Starting initial single prompt test run...")
    guided_decoding_req_idx = random.sample(
    structured_output_req_idx = random.sample(
        range(len(input_requests)),
        int(len(input_requests) * guided_decoding_ratio))
        int(len(input_requests) * structured_output_ratio))

    test_request = input_requests[0]
    test_req_extra_body = (prepare_extra_body(test_request)
                           if 0 in structured_output_req_idx else None)
    test_input = RequestFuncInput(
        model=model_id,
        prompt=test_request.prompt,
@@ -408,7 +423,7 @@ async def benchmark(
        prompt_len=test_request.prompt_len,
        output_len=test_request.expected_output_len,
        ignore_eos=ignore_eos,
        extra_body=prepare_extra_body(test_request),
        extra_body=test_req_extra_body,
    )
    test_output = await request_func(request_func_input=test_input)
    if not test_output.success:
@@ -427,7 +442,7 @@ async def benchmark(
            prompt_len=test_request.prompt_len,
            output_len=test_request.expected_output_len,
            ignore_eos=ignore_eos,
            extra_body=prepare_extra_body(test_request),
            extra_body=test_req_extra_body,
        )
        profile_output = await request_func(request_func_input=profile_input)
        if profile_output.success:
@@ -465,7 +480,7 @@ async def benchmark(
    async for i, request in get_request(input_requests, request_rate,
                                        burstiness):
        extra_body = prepare_extra_body(
            request) if i in guided_decoding_req_idx else None
            request) if i in structured_output_req_idx else None
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=request.prompt,
@@ -708,10 +723,10 @@ def main(args: argparse.Namespace):
    else:
        args.structure_type = 'guided_json'

    if args.no_guided_decoding:
        args.guided_decoding_ratio = 0
    if args.no_structured_output:
        args.structured_output_ratio = 0
    if args.save_results:
        result_file_name = f'{args.guided_decoding_ratio}guided'
        result_file_name = f'{args.structured_output_ratio}guided'
        result_file_name += f"_{backend}"
        result_file_name += f"_{args.request_rate}qps"
        result_file_name += f"_{args.model.split('/')[-1]}"
@@ -744,8 +759,8 @@ def main(args: argparse.Namespace):
            ],
            ignore_eos=args.ignore_eos,
            max_concurrency=args.max_concurrency,
            guided_decoding_ratio=args.guided_decoding_ratio,
            guided_decoding_backend=args.guided_decoding_backend,
            structured_output_ratio=args.structured_output_ratio,
            structured_output_backend=args.structured_output_backend,
            goodput_config_dict=goodput_config_dict,
        ))

@@ -943,19 +958,19 @@ if __name__ == "__main__":
        "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 "
        "and the blog: https://hao-ai-lab.github.io/blogs/distserve")

    parser.add_argument("--no-guided-decoding",
    parser.add_argument("--no-structured-output",
                        action='store_true',
                        default=False,
                        help="Whether to disable JSON decoding or not.")
    parser.add_argument("--guided-decoding-ratio",
    parser.add_argument("--structured-output-ratio",
                        type=float,
                        default=1.0,
                        help="Ratio of Guided Decoding requests")
    parser.add_argument("--guided-decoding-backend",
                        help="Ratio of Structured Outputs requests")
    parser.add_argument("--structured-output-backend",
                        type=str,
                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
                        default="xgrammar",
                        help="Backend to use for guided decoding")
                        help="Backend to use for structured outputs")

    args = parser.parse_args()
    main(args)
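For reference, prepare_extra_body above is what ends up in each benchmarked request's extra body. A hedged sketch of an equivalent hand-written client call follows; the server URL, model name, and schema are placeholders, and it assumes the openai Python client, which forwards an extra_body argument to the request payload.

# Sketch only: the schema goes under the structure-type key (here "guided_json")
# and the backend is selected via "guided_decoding_backend".
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

extra_body = {
    "guided_json": {"type": "object", "properties": {"name": {"type": "string"}}},
    "guided_decoding_backend": "xgrammar",
}

completion = client.completions.create(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model name
    prompt="Return a JSON object with a name field.",
    max_tokens=128,
    extra_body=extra_body,
)
print(completion.choices[0].text)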
64  benchmarks/run_structured_output_benchmark.sh    Executable file
@@ -0,0 +1,64 @@
#!/bin/bash

# Define the model to use
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"}

# Define the backend to use
BACKEND=${2:-"vllm"}

# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}

# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"}

GUIDED_RATIO=${6:-0.5}

# Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"

# Define QPS values to test
QPS_VALUES=(70 60 50 25 20 15 10)

# Common parameters
COMMON_PARAMS="--backend $BACKEND \
               --model $MODEL \
               --dataset $DATASET \
               --structured-output-backend $GUIDED_BACKEND \
               --structured-output-ratio $GUIDED_RATIO \
               --save-results \
               --result-dir $OUTPUT_DIR"

echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------"

# Run benchmarks with different QPS values
for qps in "${QPS_VALUES[@]}"; do
  echo "Running benchmark with QPS: $qps"

  # Get git hash and branch for the filename
  GIT_HASH=$(git rev-parse --short HEAD 2>/dev/null || echo "unknown")
  GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")

  # Construct filename for this run
  FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"

  # Run the benchmark
  python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
    --request-rate $qps \
    --result-filename "$FILENAME" \
    --port ${PORT:-8000}

  echo "Completed benchmark with QPS: $qps"
  echo "----------------------------------------"
done

echo "All benchmarks completed!"
echo "Results saved to: $OUTPUT_DIR"
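For reference, the positional arguments of this script are model, backend, dataset, structured-output backend, output directory, and structured-output ratio, so a typical invocation is ./benchmarks/run_structured_output_benchmark.sh Qwen/Qwen2.5-7B-Instruct vllm xgrammar_bench xgrammar, with the target server port overridden via the PORT environment variable (default 8000).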
@ -1,113 +1,25 @@
|
||||
{
|
||||
"$schema":
|
||||
"https://json-schema.org/draft/2020-12/schema",
|
||||
"title":
|
||||
"User Profile",
|
||||
"type":
|
||||
"object",
|
||||
"properties": {
|
||||
"userId": {
|
||||
"type": "string",
|
||||
"description": "Unique identifier for the user."
|
||||
},
|
||||
"personalInfo": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"firstName": {
|
||||
"type": "string",
|
||||
"description": "The user's first name."
|
||||
},
|
||||
"lastName": {
|
||||
"type": "string",
|
||||
"description": "The user's last name."
|
||||
},
|
||||
"age": {
|
||||
"type": "integer",
|
||||
"minimum": 0,
|
||||
"description": "The user's age."
|
||||
},
|
||||
"phoneNumbers": {
|
||||
"type":
|
||||
"array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {
|
||||
"type": "string",
|
||||
"enum": ["home", "work", "mobile"],
|
||||
"description": "Type of phone number."
|
||||
},
|
||||
"number": {
|
||||
"type": "string",
|
||||
"pattern": "^\\+?[1-9]\\d{1,14}$",
|
||||
"description": "Phone number in E.164 format."
|
||||
}
|
||||
},
|
||||
"required": ["type", "number"]
|
||||
},
|
||||
"description":
|
||||
"List of phone numbers associated with the user."
|
||||
}
|
||||
},
|
||||
"required": ["firstName", "lastName"]
|
||||
},
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {
|
||||
"type": "string",
|
||||
"description": "Street address."
|
||||
},
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "City name."
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "State or province."
|
||||
},
|
||||
"postalCode": {
|
||||
"type": "string",
|
||||
"pattern": "^\\d{5}(-\\d{4})?$",
|
||||
"description": "Postal code."
|
||||
},
|
||||
"country": {
|
||||
"type": "string",
|
||||
"description": "Country name."
|
||||
}
|
||||
},
|
||||
"required": ["street", "city", "state", "postalCode", "country"]
|
||||
},
|
||||
"preferences": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"newsletterSubscribed": {
|
||||
"type":
|
||||
"boolean",
|
||||
"description":
|
||||
"Indicates if the user is subscribed to the newsletter."
|
||||
},
|
||||
"favoriteCategories": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" },
|
||||
"race": { "type": "string" },
|
||||
"class": { "type": "string" },
|
||||
"level": { "type": "integer" },
|
||||
"background": { "type": "string" },
|
||||
"alignment": { "type": "string" },
|
||||
"backstory": { "type": "string" }
|
||||
},
|
||||
"description": "List of user's favorite categories."
|
||||
"required": [
|
||||
"name",
|
||||
"race",
|
||||
"class",
|
||||
"level",
|
||||
"background",
|
||||
"alignment",
|
||||
"backstory"
|
||||
]
|
||||
}
|
||||
},
|
||||
"required": ["newsletterSubscribed"]
|
||||
},
|
||||
"accountStatus": {
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive", "suspended"],
|
||||
"description": "Current status of the user's account."
|
||||
},
|
||||
"registrationDate": {
|
||||
"type": "string",
|
||||
"format": "date-time",
|
||||
"description": "ISO 8601 formatted date-time of user registration."
|
||||
}
|
||||
},
|
||||
"required":
|
||||
["userId", "personalInfo", "address", "accountStatus", "registrationDate"]
|
||||
}
|
||||
|
||||
|
@@ -1,12 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

EOS_TOKEN_ID = 50256

@@ -36,13 +37,21 @@ def create_scheduler(
        swap_space=0,
        cache_dtype="auto",
    )
    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
    )
    cache_config.num_gpu_blocks = 10000
    return Scheduler(scheduler_config,
    return Scheduler(
        scheduler_config,
        model_config,
        cache_config,
        speculative_config=None,
        lora_config=None,
        log_stats=True)
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_requests(
@@ -249,7 +258,9 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[])
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
@@ -299,7 +310,9 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[])
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
@@ -347,7 +360,9 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[])
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
@@ -392,7 +407,9 @@ def test_stop_via_update_from_output():
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[])
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
@@ -29,6 +29,7 @@ def sample_regex():
            r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)")


# Note: Ensure this only uses attributes compatible with xgrammar
@pytest.fixture
def sample_json_schema():
    return {
@@ -44,9 +45,7 @@ def sample_json_schema():
            "type": "array",
            "items": {
                "type": "string",
                "maxLength": 10
            },
            "minItems": 3
            }
        },
        "work_history": {
            "type": "array",
@@ -71,8 +70,9 @@ def sample_json_schema():
    }


# A schema unsupported by xgrammar
@pytest.fixture
def sample_complex_json_schema():
def unsupported_json_schema():
    return {
        "type": "object",
        "properties": {
@@ -150,7 +150,19 @@ def sample_guided_choice():


@pytest.fixture
def sample_sql_statements():
def sample_sql_ebnf():
    return """
root ::= select_statement
select_statement ::= "SELECT" column "from" table "where" condition
column ::= "col_1" | "col_2"
table ::= "table_1" | "table_2"
condition ::= column "=" number
number ::= "1" | "2"
"""


@pytest.fixture
def sample_sql_lark():
    return ("""
start: select_statement
select_statement: "SELECT" column "from" table "where" condition
0    tests/v1/entrypoints/llm/__init__.py    Normal file
269  tests/v1/entrypoints/llm/test_struct_output_generate.py    Normal file
@ -0,0 +1,269 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import json
|
||||
|
||||
import jsonschema
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||
|
||||
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
|
||||
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_json_completion(monkeypatch, sample_json_schema,
|
||||
guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=sample_json_schema,
|
||||
backend=guided_decoding_backend))
|
||||
outputs = llm.generate(prompts=[
|
||||
f"Give an example JSON for an employee profile "
|
||||
f"that fits this schema: {sample_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
output_json = json.loads(generated_text)
|
||||
jsonschema.validate(instance=output_json, schema=sample_json_schema)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_json_object(monkeypatch, guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
max_tokens=100,
|
||||
n=2,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json_object=True,
|
||||
backend=guided_decoding_backend))
|
||||
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a JSON object with curly braces for a person with "
|
||||
"name and age fields for John Smith who is 31 years old."),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
|
||||
for i in range(2):
|
||||
generated_text = output.outputs[i].text
|
||||
print(generated_text)
|
||||
assert generated_text is not None
|
||||
|
||||
# Parse to verify it is valid JSON
|
||||
parsed_json = json.loads(generated_text)
|
||||
assert isinstance(parsed_json, dict)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_json_unsupported_schema(monkeypatch, unsupported_json_schema,
|
||||
guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=1.0,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
json=unsupported_json_schema,
|
||||
backend=guided_decoding_backend))
|
||||
with pytest.raises(ValueError,
|
||||
match="The provided JSON schema contains features "
|
||||
"not supported by xgrammar."):
|
||||
llm.generate(prompts=[
|
||||
f"Give an example JSON for an employee profile "
|
||||
f"that fits this schema: {unsupported_json_schema}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_grammar_ebnf(monkeypatch, sample_sql_ebnf,
|
||||
guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
grammar=sample_sql_ebnf,
|
||||
backend=guided_decoding_backend))
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1"),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||
" ", "")
|
||||
|
||||
assert generated_text.strip() == ground_truth
|
||||
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_grammar_lark(monkeypatch, sample_sql_lark,
|
||||
guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
grammar=sample_sql_lark,
|
||||
backend=guided_decoding_backend))
|
||||
outputs = llm.generate(
|
||||
prompts=("Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1"),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
|
||||
generated_text = output.outputs[0].text
|
||||
assert generated_text is not None
|
||||
|
||||
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||
from lark import Lark
|
||||
parser = Lark(sample_sql_lark)
|
||||
parser.parse(generated_text)
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
|
||||
" ", "")
|
||||
|
||||
assert generated_text.strip() == ground_truth
|
||||
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_grammar_ebnf_invalid(monkeypatch,
|
||||
guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
max_tokens=1000,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
grammar="not a grammar",
|
||||
backend=guided_decoding_backend))
|
||||
with pytest.raises(ValueError,
|
||||
match="Failed to convert the grammar "
|
||||
"from Lark to EBNF."):
|
||||
llm.generate(
|
||||
prompts=("Generate a sql statement that selects col_1 from "
|
||||
"table_1 where it is equal to 1"),
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_regex(monkeypatch, sample_regex, guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
regex=sample_regex,
|
||||
backend=guided_decoding_backend))
|
||||
with pytest.raises(ValueError,
|
||||
match="Regex guided decoding is not supported."):
|
||||
llm.generate(prompts=[
|
||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||
] * 2,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
# Once regex is supported --
|
||||
#assert outputs is not None
|
||||
#for output in outputs:
|
||||
# assert output is not None
|
||||
# assert isinstance(output, RequestOutput)
|
||||
# prompt = output.prompt
|
||||
# generated_text = output.outputs[0].text
|
||||
# print(generated_text)
|
||||
# assert generated_text is not None
|
||||
# assert re.fullmatch(sample_regex, generated_text) is not None
|
||||
# print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
@pytest.mark.parametrize("guided_decoding_backend",
|
||||
GUIDED_DECODING_BACKENDS_V1)
|
||||
def test_guided_choice_completion(monkeypatch, sample_guided_choice,
|
||||
guided_decoding_backend: str):
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||
llm = LLM(model=MODEL_NAME, max_model_len=1024)
|
||||
sampling_params = SamplingParams(temperature=0.8,
|
||||
top_p=0.95,
|
||||
guided_decoding=GuidedDecodingParams(
|
||||
choice=sample_guided_choice,
|
||||
backend=guided_decoding_backend))
|
||||
outputs = llm.generate(
|
||||
prompts="The best language for type-safe systems programming is ",
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=True)
|
||||
|
||||
assert outputs is not None
|
||||
for output in outputs:
|
||||
assert output is not None
|
||||
assert isinstance(output, RequestOutput)
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(generated_text)
|
||||
assert generated_text is not None
|
||||
assert generated_text in sample_guided_choice
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
0    tests/v1/structured_output/__init__.py    Normal file
196  tests/v1/structured_output/test_utils.py    Normal file
@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.v1.structured_output.utils import (
|
||||
has_xgrammar_unsupported_json_features)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_string_schemas():
|
||||
return [
|
||||
{
|
||||
"type": "string",
|
||||
"pattern": "^[a-zA-Z]+$"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive", "pending"]
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"maxLength": 100
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "email"
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_integer_schemas():
|
||||
return [
|
||||
{
|
||||
"type": "integer",
|
||||
"minimum": 0
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"maximum": 120
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"exclusiveMinimum": 120
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"exclusiveMaximum": 120
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"multipleOf": 120
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_number_schemas():
|
||||
return [
|
||||
{
|
||||
"type": "number",
|
||||
"minimum": 0
|
||||
},
|
||||
{
|
||||
"type": "number",
|
||||
"maximum": 120
|
||||
},
|
||||
{
|
||||
"type": "number",
|
||||
"exclusiveMinimum": 120
|
||||
},
|
||||
{
|
||||
"type": "number",
|
||||
"exclusiveMaximum": 120
|
||||
},
|
||||
{
|
||||
"type": "number",
|
||||
"multipleOf": 120
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_array_schemas():
|
||||
return [
|
||||
{
|
||||
"type": "array",
|
||||
"uniqueItems": True
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"contains": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"minContains": 1
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"maxContains": 5
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"minItems": 1
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"maxItems": 10
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unsupported_object_schemas():
|
||||
return [
|
||||
{
|
||||
"type": "object",
|
||||
"minProperties": 1
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"maxProperties": 5
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-z]+$"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"patternProperties": {
|
||||
"^S": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def supported_schema():
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string"
|
||||
},
|
||||
"age": {
|
||||
"type": "integer"
|
||||
},
|
||||
"status": {
|
||||
"type": "string"
|
||||
},
|
||||
"scores": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {
|
||||
"type": "string"
|
||||
},
|
||||
"city": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema_type", [
|
||||
"unsupported_string_schemas", "unsupported_integer_schemas",
|
||||
"unsupported_number_schemas", "unsupported_array_schemas",
|
||||
"unsupported_object_schemas"
|
||||
])
|
||||
def test_unsupported_json_features_by_type(schema_type, request):
|
||||
schemas = request.getfixturevalue(schema_type)
|
||||
for schema in schemas:
|
||||
assert has_xgrammar_unsupported_json_features(
|
||||
schema), f"Schema should be unsupported: {schema}"
|
||||
|
||||
|
||||
def test_supported_json_features(supported_schema):
|
||||
assert not has_xgrammar_unsupported_json_features(
|
||||
supported_schema), "Schema should be supported"
|
@@ -72,6 +72,8 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )


@@ -135,6 +137,8 @@ def test_update_states_request_finished(model_runner):
        num_common_prefix_blocks=0,
        finished_req_ids={req_id},
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    metadata_before = model_runner.input_batch.sampling_metadata
@@ -165,6 +169,8 @@ def test_update_states_request_resumed(model_runner):
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    model_runner._update_states(scheduler_output)
@@ -190,6 +196,8 @@ def test_update_states_request_resumed(model_runner):
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    metadata_before = model_runner.input_batch.sampling_metadata
@@ -221,6 +229,8 @@ def test_update_states_no_changes(model_runner):
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    metadata_before = model_runner.input_batch.sampling_metadata
@@ -256,6 +266,8 @@ def test_update_states_request_unscheduled(model_runner):
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None,
    )

    metadata_before = model_runner._update_states(scheduler_output)
@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import argparse
import asyncio
import concurrent
@@ -8,6 +10,7 @@ import datetime
import enum
import gc
import getpass
import importlib
import importlib.metadata
import importlib.util
import inspect
@@ -23,6 +26,7 @@ import tempfile
import threading
import time
import traceback
import types
import uuid
import warnings
import weakref
@@ -982,7 +986,7 @@ def current_stream() -> torch.cuda.Stream:
    return _current_stream


def enable_trace_function_call_for_thread(vllm_config: "VllmConfig") -> None:
def enable_trace_function_call_for_thread(vllm_config: VllmConfig) -> None:
    """Set up function tracing for the current thread,
    if enabled via the VLLM_TRACE_FUNCTION environment variable
    """
@@ -1977,7 +1981,7 @@ class MemorySnapshot:
        self.non_torch_memory = self.cuda_memory - self.torch_memory
        self.timestamp = time.time()

    def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
    def __sub__(self, other: MemorySnapshot) -> MemorySnapshot:
        return MemorySnapshot(
            torch_peak=self.torch_peak - other.torch_peak,
            cuda_memory=self.cuda_memory - other.cuda_memory,
@@ -2306,3 +2310,54 @@ def warn_for_unimplemented_methods(cls: type[T]) -> type[T]:

    type.__setattr__(cls, '__init__', wrapped_init)
    return cls


class LazyLoader(types.ModuleType):
    """
    LazyLoader module borrowed from TensorFlow
    https://github.com/tensorflow/tensorflow/blob/main/tensorflow/python/util/lazy_loader.py
    with an addition of "module caching".

    Lazily import a module, mainly to avoid pulling in large dependencies.
    Modules such as `xgrammar` can have import-time side effects, so we
    only want to load them when they are needed, delaying all eager effects.
    """

    def __init__(
        self,
        local_name: str,
        parent_module_globals: dict[str, Any],
        name: str,
    ):
        self._local_name = local_name
        self._parent_module_globals = parent_module_globals
        self._module: types.ModuleType | None = None

        super().__init__(str(name))

    def _load(self) -> types.ModuleType:
        # Import the target module and insert it into the parent's namespace
        try:
            module = importlib.import_module(self.__name__)
            self._parent_module_globals[self._local_name] = module
            # The additional add to sys.modules
            # ensures the library is actually loaded.
            sys.modules[self._local_name] = module
        except ModuleNotFoundError as err:
            raise err from None

        # Update this object's dict so that if someone keeps a
        # reference to the LazyLoader, lookups are efficient
        # (__getattr__ is only called on lookups that fail).
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: Any) -> Any:
        if self._module is None:
            self._module = self._load()
        return getattr(self._module, item)

    def __dir__(self) -> list[str]:
        if self._module is None:
            self._module = self._load()
        return dir(self._module)
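A usage sketch for the new LazyLoader follows; the alias name and the xgrammar call site are illustrative assumptions, not code introduced by this commit.

# Sketch only: defer the heavy xgrammar import until the first attribute access.
from vllm.utils import LazyLoader

xgr = LazyLoader("xgr", globals(), "xgrammar")  # nothing is imported yet


def build_tokenizer_info(tokenizer, vocab_size: int):
    # The first attribute access triggers the real import and caches the module.
    return xgr.TokenizerInfo.from_huggingface(tokenizer, vocab_size=vocab_size)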
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Iterable
|
||||
@ -18,6 +20,7 @@ from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -32,12 +35,14 @@ class Scheduler:
|
||||
lora_config: Optional[LoRAConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
log_stats: bool,
|
||||
structured_output_manager: StructuredOutputManager,
|
||||
) -> None:
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.speculative_config = speculative_config
|
||||
self.log_stats = log_stats
|
||||
self.structured_output_manager = structured_output_manager
|
||||
|
||||
# Scheduling constraints.
|
||||
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
|
||||
@ -97,7 +102,7 @@ class Scheduler:
|
||||
self.encoder_cache_manager = EncoderCacheManager(
|
||||
cache_size=encoder_cache_size)
|
||||
|
||||
def schedule(self) -> "SchedulerOutput":
|
||||
def schedule(self) -> SchedulerOutput:
|
||||
# NOTE(woosuk) on the scheduling algorithm:
|
||||
# There's no "decoding phase" nor "prefill phase" in the scheduler.
|
||||
# Each request just has the num_computed_tokens and
|
||||
@ -114,6 +119,14 @@ class Scheduler:
|
||||
scheduled_running_reqs: list[Request] = []
|
||||
preempted_reqs: list[Request] = []
|
||||
|
||||
# NOTE: structured_output_request_ids maps
# the request_id of a request that uses structured output
# to its index in the running request batch.
# This helps us determine how to slice the grammar bitmask
# and only apply a valid mask to requests that
# use structured decoding.
|
||||
structured_output_request_ids: dict[str, int] = {}
|
||||
|
||||
req_to_new_block_ids: dict[str, list[int]] = {}
|
||||
num_scheduled_tokens: dict[str, int] = {}
|
||||
token_budget = self.max_num_scheduled_tokens
|
||||
@ -184,6 +197,12 @@ class Scheduler:
|
||||
# Schedule the request.
|
||||
scheduled_running_reqs.append(request)
|
||||
self.scheduled_req_ids.add(request.request_id)
|
||||
if request.use_structured_output:
|
||||
# PERF: in the case of chunked prefill,
# the request might not include any new tokens.
# Therefore, we might introduce some additional
# cycles to fill in the bitmask, which could be a big no-op.
|
||||
structured_output_request_ids[request.request_id] = req_index
|
||||
req_to_new_block_ids[request.request_id] = [
|
||||
b.block_id for b in new_blocks
|
||||
]
|
||||
@ -219,6 +238,10 @@ class Scheduler:
|
||||
if req.lora_request and req.lora_request.lora_int_id > 0)
|
||||
assert len(requested_loras) <= self.lora_config.max_loras
|
||||
|
||||
# Use a temporary deque to collect requests that need to be skipped
|
||||
# and put back at the head of the waiting queue later
|
||||
waiting_for_fsm: deque[Request] = deque()
|
||||
|
||||
# Next, schedule the WAITING requests.
|
||||
if not preempted_reqs:
|
||||
while self.waiting and token_budget > 0:
|
||||
@ -227,6 +250,16 @@ class Scheduler:
|
||||
|
||||
request = self.waiting[0]
|
||||
|
||||
if request.status == RequestStatus.WAITING_FOR_FSM:
|
||||
structured_output_req = request.structured_output_request
|
||||
if structured_output_req and structured_output_req.grammar:
|
||||
request.status = RequestStatus.WAITING
|
||||
else:
|
||||
waiting_structured_output_req = self.waiting.popleft()
|
||||
waiting_for_fsm.appendleft(
|
||||
waiting_structured_output_req)
|
||||
continue
|
||||
|
||||
# Check that adding the request still respects the max_loras
|
||||
# constraint.
|
||||
if self.lora_config and request.lora_request:
|
||||
@ -281,6 +314,10 @@ class Scheduler:
|
||||
break
|
||||
|
||||
self.waiting.popleft()
|
||||
if request.use_structured_output:
|
||||
structured_output_request_ids[
|
||||
request.request_id] = req_index
|
||||
req_index += 1
|
||||
self.running.append(request)
|
||||
self.scheduled_req_ids.add(request.request_id)
|
||||
self.request_scheduled(request, scheduled_timestamp)
|
||||
@ -311,6 +348,10 @@ class Scheduler:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_budget = new_encoder_budget
|
||||
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if waiting_for_fsm:
|
||||
self.waiting.extendleft(waiting_for_fsm)
|
||||
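# --- Illustrative sketch (not part of the diff) ---
# The appendleft()/extendleft() pairing above restores skipped requests in
# their original FIFO order. Standalone toy version, with ints standing in
# for Request objects:
#
#     from collections import deque
#     waiting = deque([1, 2, 3, 4])               # 1 and 2 still wait for their FSM
#     skipped = deque()
#     for _ in range(2):
#         skipped.appendleft(waiting.popleft())   # skipped == deque([2, 1])
#     waiting.extendleft(skipped)                 # waiting == deque([1, 2, 3, 4])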
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||
@ -331,6 +372,11 @@ class Scheduler:
|
||||
self.kv_cache_manager.get_num_common_prefix_blocks(
|
||||
any_request, len(self.running)))
|
||||
|
||||
grammar_bitmask = self.structured_output_manager.grammar_bitmask(
|
||||
self.requests,
|
||||
structured_output_request_ids,
|
||||
len(self.running),
|
||||
)
|
||||
# Construct the scheduler output.
|
||||
new_reqs_data = [
|
||||
NewRequestData.from_request(req,
|
||||
@ -369,6 +415,8 @@ class Scheduler:
|
||||
# the previous and the current steps.
|
||||
finished_req_ids=self.finished_req_ids,
|
||||
free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
|
||||
structured_output_request_ids=structured_output_request_ids,
|
||||
grammar_bitmask=grammar_bitmask,
|
||||
)
|
||||
|
||||
self.finished_req_ids = set()
|
||||
@ -381,7 +429,7 @@ class Scheduler:
|
||||
num_scheduled_spec_tokens: int,
|
||||
new_block_ids: list[int],
|
||||
resumed_from_preemption: bool,
|
||||
) -> "CachedRequestData":
|
||||
) -> CachedRequestData:
|
||||
# OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
|
||||
# them at each scheduling step.
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
@ -474,8 +522,8 @@ class Scheduler:
|
||||
|
||||
def update_from_output(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
model_runner_output: "ModelRunnerOutput",
|
||||
scheduler_output: SchedulerOutput,
|
||||
model_runner_output: ModelRunnerOutput,
|
||||
) -> EngineCoreOutputs:
|
||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||
spec_token_ids = model_runner_output.spec_token_ids
|
||||
@ -565,6 +613,15 @@ class Scheduler:
|
||||
# the outer lists can be of length > 1.
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
if new_token_ids and request.use_structured_output:
|
||||
# NOTE: structured_output_request
# should not be None if use_structured_output; we have
# checked that above, so it is safe to ignore the type warning
|
||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||
request.request_id,
|
||||
new_token_ids,
|
||||
)
|
||||
|
||||
# Transmit partial if chunked prefill & prompt logprobs is enabled
|
||||
if new_token_ids or prompt_logprobs_tensors is not None:
|
||||
# Add EngineCoreOutput for this Request.
|
||||
|
@ -1,9 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.base import PlaceholderRange
|
||||
@ -17,20 +22,20 @@ class NewRequestData:
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
prompt: Optional[str]
|
||||
mm_inputs: list["MultiModalKwargs"]
|
||||
mm_inputs: list[MultiModalKwargs]
|
||||
mm_hashes: list[str]
|
||||
mm_positions: list["PlaceholderRange"]
|
||||
sampling_params: "SamplingParams"
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: SamplingParams
|
||||
block_ids: list[int]
|
||||
num_computed_tokens: int
|
||||
lora_request: Optional["LoRARequest"]
|
||||
lora_request: Optional[LoRARequest]
|
||||
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls,
|
||||
request: "Request",
|
||||
request: Request,
|
||||
block_ids: list[int],
|
||||
) -> "NewRequestData":
|
||||
) -> NewRequestData:
|
||||
return cls(
|
||||
req_id=request.request_id,
|
||||
prompt_token_ids=request.prompt_token_ids,
|
||||
@ -60,11 +65,11 @@ class CachedRequestData:
|
||||
@classmethod
|
||||
def from_request(
|
||||
cls,
|
||||
request: "Request",
|
||||
request: Request,
|
||||
resumed_from_preemption: bool,
|
||||
new_token_ids: list[int],
|
||||
new_block_ids: list[int],
|
||||
) -> "CachedRequestData":
|
||||
) -> CachedRequestData:
|
||||
return cls(
|
||||
req_id=request.request_id,
|
||||
resumed_from_preemption=resumed_from_preemption,
|
||||
@ -111,3 +116,9 @@ class SchedulerOutput:
|
||||
# list of (req_id, encoder_input_index) tuples.
|
||||
# Used to free the encoder cache.
|
||||
free_encoder_input_ids: list[tuple[str, int]]
|
||||
|
||||
# Dict of request ids to their index within the batch
|
||||
# for filling the next token bitmask
|
||||
structured_output_request_ids: dict[str, int]
|
||||
# the bitmask for the whole batch
|
||||
grammar_bitmask: Optional[npt.NDArray[np.int32]]
|
||||
|
@ -72,9 +72,7 @@ class AsyncLLM(EngineClient):
|
||||
|
||||
# Processor (converts Inputs --> EngineCoreRequests).
|
||||
self.processor = Processor(
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
lora_config=vllm_config.lora_config,
|
||||
vllm_config=vllm_config,
|
||||
tokenizer=self.tokenizer,
|
||||
input_registry=input_registry,
|
||||
)
|
||||
|
@ -29,6 +29,7 @@ from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -61,6 +62,8 @@ class EngineCore:
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
self.structured_output_manager = StructuredOutputManager(vllm_config)
|
||||
|
||||
# Setup scheduler.
|
||||
self.scheduler = Scheduler(
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
@ -69,6 +72,7 @@ class EngineCore:
|
||||
lora_config=vllm_config.lora_config,
|
||||
speculative_config=vllm_config.speculative_config,
|
||||
log_stats=self.log_stats,
|
||||
structured_output_manager=self.structured_output_manager,
|
||||
)
|
||||
|
||||
# Setup MM Input Mapper.
|
||||
@ -131,6 +135,9 @@ class EngineCore:
|
||||
request.mm_inputs, request.mm_hashes)
|
||||
|
||||
req = Request.from_engine_core_request(request)
|
||||
if req.use_structured_output:
|
||||
# Start grammar compilation asynchronously
|
||||
self.structured_output_manager.populate_cache(req)
|
||||
|
||||
self.scheduler.add_request(req)
|
||||
|
||||
@ -148,11 +155,24 @@ class EngineCore:
|
||||
|
||||
if not self.scheduler.has_unfinished_requests():
|
||||
return EngineCoreOutputs(
|
||||
outputs=[], scheduler_stats=self.scheduler.make_stats())
|
||||
outputs=[],
|
||||
scheduler_stats=self.scheduler.make_stats(),
|
||||
)
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
|
||||
# This case may occur when the only unfinished requests are
|
||||
# structured output requests where the grammar has not finished
|
||||
# compiling yet, so there's nothing to run.
|
||||
if scheduler_output.total_num_scheduled_tokens == 0:
|
||||
return EngineCoreOutputs(
|
||||
outputs=[],
|
||||
scheduler_stats=self.scheduler.make_stats(),
|
||||
)
|
||||
|
||||
output = self.model_executor.execute_model(scheduler_output)
|
||||
engine_core_outputs = self.scheduler.update_from_output(
|
||||
scheduler_output, output) # type: ignore
|
||||
|
||||
return engine_core_outputs
|
||||
|
||||
def step_with_batch_queue(self) -> Optional[EngineCoreOutputs]:
|
||||
|
@ -66,9 +66,7 @@ class LLMEngine:
|
||||
self.tokenizer.ping()
|
||||
|
||||
# Processor (convert Inputs --> EngineCoreRequests)
|
||||
self.processor = Processor(model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
lora_config=vllm_config.lora_config,
|
||||
self.processor = Processor(vllm_config=vllm_config,
|
||||
tokenizer=self.tokenizer,
|
||||
input_registry=input_registry,
|
||||
mm_registry=mm_registry)
|
||||
|
@ -4,7 +4,7 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional, Union
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, ModelConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType, SingletonInputsAdapter)
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs
|
||||
@ -19,39 +19,41 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
|
||||
from vllm.v1.structured_output.utils import validate_structured_output_request
|
||||
|
||||
|
||||
class Processor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
cache_config: CacheConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
vllm_config: VllmConfig,
|
||||
tokenizer: BaseTokenizerGroup,
|
||||
input_registry: InputRegistry = INPUT_REGISTRY,
|
||||
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
|
||||
):
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.lora_config = lora_config
|
||||
self.vllm_config = vllm_config
|
||||
self.model_config = vllm_config.model_config
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.lora_config = vllm_config.lora_config
|
||||
self.decoding_config = vllm_config.decoding_config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.generation_config_fields = model_config.try_get_generation_config(
|
||||
)
|
||||
self.input_preprocessor = InputPreprocessor(model_config,
|
||||
self.generation_config_fields = (
|
||||
self.model_config.try_get_generation_config())
|
||||
self.input_preprocessor = InputPreprocessor(self.model_config,
|
||||
self.tokenizer,
|
||||
mm_registry)
|
||||
self.input_processor = input_registry.create_input_processor(
|
||||
model_config)
|
||||
self.model_config)
|
||||
|
||||
# Multi-modal (huggingface) input mapper
|
||||
self.mm_input_cache_client = MMInputCacheClient(model_config)
|
||||
self.mm_input_cache_client = MMInputCacheClient(self.model_config)
|
||||
|
||||
# Multi-modal hasher (for images)
|
||||
self.use_hash = (not model_config.disable_mm_preprocessor_cache) or \
|
||||
cache_config.enable_prefix_caching
|
||||
self.use_hash = (
|
||||
not self.model_config.disable_mm_preprocessor_cache) or \
|
||||
self.cache_config.enable_prefix_caching
|
||||
|
||||
def _validate_logprobs(
|
||||
self,
|
||||
@ -80,6 +82,8 @@ class Processor:
|
||||
self,
|
||||
params: SamplingParams,
|
||||
) -> None:
|
||||
self._validate_structured_output(params)
|
||||
|
||||
if params.allowed_token_ids is None:
|
||||
return
|
||||
if not params.allowed_token_ids:
|
||||
@ -125,6 +129,21 @@ class Processor:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
|
||||
def _validate_structured_output(self, params: SamplingParams) -> None:
|
||||
if not params.guided_decoding or not self.decoding_config:
|
||||
return
|
||||
if self.decoding_config.guided_decoding_backend != "xgrammar":
|
||||
raise ValueError(
|
||||
"Only xgrammar structured output is supported in V1.")
|
||||
if (params.guided_decoding.backend
|
||||
and params.guided_decoding.backend != 'xgrammar'):
|
||||
raise ValueError(
|
||||
"Only xgrammar structured output is supported in V1.")
|
||||
if self.vllm_config.speculative_config:
|
||||
raise ValueError("Structured output is not supported with "
|
||||
"speculative decoding.")
|
||||
validate_structured_output_request(params)
|
||||
|
||||
def process_inputs(
|
||||
self,
|
||||
request_id: str,
|
||||
|
@ -3,13 +3,15 @@
|
||||
import enum
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||
EngineCoreRequest, FinishReason)
|
||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||
from vllm.v1.utils import ConstantList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
|
||||
@ -27,15 +29,19 @@ class Request:
|
||||
sampling_params: SamplingParams,
|
||||
eos_token_id: Optional[int],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.sampling_params = sampling_params
|
||||
# Because of LoRA, the eos token id can be different for each request.
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.structured_output_request = structured_output_request
|
||||
|
||||
self.status = RequestStatus.WAITING
|
||||
self.status = (RequestStatus.WAITING_FOR_FSM
|
||||
if sampling_params.guided_decoding is not None else
|
||||
RequestStatus.WAITING)
|
||||
self.events: list[EngineCoreEvent] = []
|
||||
self.stop_reason: Union[int, str, None] = None
|
||||
assert sampling_params.max_tokens is not None
|
||||
@ -78,6 +84,8 @@ class Request:
|
||||
eos_token_id=request.eos_token_id,
|
||||
arrival_time=request.arrival_time,
|
||||
lora_request=request.lora_request,
|
||||
structured_output_request=StructuredOutputRequest(
|
||||
sampling_params=request.sampling_params),
|
||||
)
|
||||
|
||||
def queued(self, timestamp: Optional[float] = None) -> None:
|
||||
@ -134,18 +142,23 @@ class Request:
|
||||
num_tokens = self.mm_positions[input_id]["length"]
|
||||
return num_tokens
|
||||
|
||||
@property
|
||||
def use_structured_output(self) -> bool:
|
||||
return self.sampling_params.guided_decoding is not None
|
||||
|
||||
|
||||
class RequestStatus(enum.IntEnum):
|
||||
"""Status of a request."""
|
||||
WAITING = 0
|
||||
RUNNING = 1
|
||||
PREEMPTED = 2
|
||||
# Note: anything after PREEMPTED (2) will be considered
|
||||
WAITING = enum.auto()
|
||||
WAITING_FOR_FSM = enum.auto()
|
||||
RUNNING = enum.auto()
|
||||
PREEMPTED = enum.auto()
|
||||
# Note: anything after PREEMPTED will be considered
|
||||
# as a finished status.
|
||||
FINISHED_STOPPED = 3
|
||||
FINISHED_LENGTH_CAPPED = 4
|
||||
FINISHED_ABORTED = 5
|
||||
FINISHED_IGNORED = 6
|
||||
FINISHED_STOPPED = enum.auto()
|
||||
FINISHED_LENGTH_CAPPED = enum.auto()
|
||||
FINISHED_ABORTED = enum.auto()
|
||||
FINISHED_IGNORED = enum.auto()
|
||||
|
||||
@staticmethod
|
||||
def is_finished(status: "RequestStatus") -> bool:
|
||||
|
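# --- Illustrative sketch (not part of the diff) ---
# With enum.auto() on an IntEnum the members get increasing integer values, so
# the "anything after PREEMPTED is a finished status" convention above can be
# expressed as a plain comparison. The body of is_finished is truncated in this
# hunk; the toy version below only mirrors the stated convention and is not the
# diff's actual implementation.
import enum

class _Status(enum.IntEnum):
    WAITING = enum.auto()
    WAITING_FOR_FSM = enum.auto()
    RUNNING = enum.auto()
    PREEMPTED = enum.auto()
    FINISHED_STOPPED = enum.auto()

def _is_finished(status: _Status) -> bool:
    return status > _Status.PREEMPTED

assert not _is_finished(_Status.RUNNING)
assert _is_finished(_Status.FINISHED_STOPPED)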
152
vllm/v1/structured_output/__init__.py
Normal file
152
vllm/v1/structured_output/__init__.py
Normal file
@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import multiprocessing
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
|
||||
StructuredOutputOptions)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import xgrammar as xgr
|
||||
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StructuredOutputManager:
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500):
|
||||
tokenizer_group = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
parallel_config=vllm_config.parallel_config,
|
||||
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
|
||||
tokenizer_group.ping()
|
||||
self.vocab_size = vllm_config.model_config.get_vocab_size()
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
|
||||
tokenizer, vocab_size=self.vocab_size)
|
||||
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
|
||||
|
||||
self.max_cache_size = max_cache_size
|
||||
self.request_key_to_grammar: OrderedDict[StructuredOutputKey,
|
||||
Grammar] = OrderedDict()
|
||||
|
||||
# The default max_workers if not specified is the number of CPUs * 5,
|
||||
# which is way too high since these tasks are CPU-bound, not I/O bound.
|
||||
# We also know we would never dominate CPU usage with just grammar
|
||||
# compilation, so we set it to half the number of CPUs.
|
||||
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._grammar_bitmask = xgr.allocate_token_bitmask(
|
||||
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
|
||||
|
||||
def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
|
||||
# We need to pop and re-insert the grammar here for LRU cache
|
||||
# of request_key_to_grammar
|
||||
if key in self.request_key_to_grammar:
|
||||
# Move accessed item to the end (most recently used)
|
||||
value = self.request_key_to_grammar.pop(key)
|
||||
if value is not None:
|
||||
self.request_key_to_grammar[key] = value
|
||||
return value
|
||||
return None
|
||||
|
||||
def populate_cache(self, request: Request) -> None:
|
||||
if request.structured_output_request is None:
|
||||
return
|
||||
|
||||
grammar = self.request_key_to_grammar.get(
|
||||
request.structured_output_request.structured_output_key)
|
||||
if grammar:
|
||||
request.structured_output_request.grammar = copy.copy(grammar)
|
||||
return
|
||||
request.structured_output_request.grammar = self.cache(request)
|
||||
|
||||
def cache(self, request: Request):
|
||||
return self.executor.submit(self._executor_loop, request)
|
||||
|
||||
def _executor_loop(self, request: Request) -> Grammar:
|
||||
# NOTE: The structured_output_request should never be
|
||||
# None in this case, but mypy can't infer this
|
||||
# correctly, so we need to ignore the error here.
|
||||
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
|
||||
grammar = self.request_key_to_grammar.get(key)
|
||||
if grammar is not None:
|
||||
return copy.copy(grammar)
|
||||
grammar = self.initialize_grammar(key)
|
||||
# If cache is full, remove the least recently used item
|
||||
if len(self.request_key_to_grammar) >= self.max_cache_size:
|
||||
self.request_key_to_grammar.popitem(last=False)
|
||||
self.request_key_to_grammar[key] = grammar
|
||||
return copy.copy(grammar)
|
||||
|
||||
def initialize_grammar(self, key: StructuredOutputKey) -> Grammar:
|
||||
# Note that the request was validated in the engine core client,
|
||||
# so at this point we know it is a supported type of request.
|
||||
#
|
||||
# TODO: we still need to handle xgrammar compilation failures
|
||||
request_type, grammar_spec = key
|
||||
|
||||
if request_type == StructuredOutputOptions.JSON:
|
||||
# TODO -- allow any_whitespace to be configurable
|
||||
# pending merge of https://github.com/vllm-project/vllm/pull/12744
|
||||
ctx = self.compiler.compile_json_schema(grammar_spec,
|
||||
any_whitespace=False)
|
||||
elif request_type == StructuredOutputOptions.JSON_OBJECT:
|
||||
ctx = self.compiler.compile_builtin_json_grammar()
|
||||
elif request_type == StructuredOutputOptions.GRAMMAR:
|
||||
ctx = self.compiler.compile_grammar(grammar_spec)
|
||||
else:
|
||||
logger.error("Validation should have already occurred. "
|
||||
"Please file an issue.")
|
||||
raise ValueError(
|
||||
f"grammar is not of valid supported types. ({request_type!s})")
|
||||
|
||||
return Grammar(
|
||||
matcher=xgr.GrammarMatcher(ctx),
|
||||
vocab_size=self.vocab_size,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
def grammar_bitmask(
|
||||
self,
|
||||
requests: dict[str, Request],
|
||||
structured_output_request_ids: dict[str, int],
|
||||
batch_len: int,
|
||||
) -> Optional[npt.NDArray[np.int32]]:
|
||||
# Prepare the structured output bitmask for this batch.
|
||||
if not structured_output_request_ids:
|
||||
return None
|
||||
|
||||
# Fill the bitmask using the index of each request equal to its
|
||||
# position in the batch. Resize the bitmask down to the size of
|
||||
# the batch.
|
||||
bitmask_tensor = self._grammar_bitmask
|
||||
for req_id, batch_index in structured_output_request_ids.items():
|
||||
request = requests[req_id].structured_output_request
|
||||
assert request is not None and request.grammar is not None
|
||||
if not request.grammar.matcher.is_terminated():
|
||||
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
|
||||
if batch_len < self._grammar_bitmask.shape[0]:
|
||||
bitmask_tensor = self._grammar_bitmask[:batch_len]
|
||||
|
||||
# After finishing with the xgrammar operations, we convert to
|
||||
# np.ndarray, because that is much more efficient for serialization
|
||||
# and deserialization when sending this to the GPU workers.
|
||||
return bitmask_tensor.numpy()
|
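# --- Illustrative sketch (not part of the diff) ---
# The OrderedDict-based LRU cache used by StructuredOutputManager, in miniature:
# entries are popped and re-inserted on access so that popitem(last=False)
# always evicts the least recently used key. Self-contained toy version:
from collections import OrderedDict
from typing import Optional

_MAX_SIZE = 2
_cache: OrderedDict[str, str] = OrderedDict()

def _lru_get(key: str) -> Optional[str]:
    if key in _cache:
        value = _cache.pop(key)
        _cache[key] = value            # move to the end (most recently used)
        return value
    return None

def _lru_put(key: str, value: str) -> None:
    if len(_cache) >= _MAX_SIZE:
        _cache.popitem(last=False)     # evict the least recently used entry
    _cache[key] = value

_lru_put("a", "grammar-a")
_lru_put("b", "grammar-b")
_lru_get("a")                          # "a" becomes most recently used
_lru_put("c", "grammar-c")             # evicts "b", not "a"
assert _lru_get("b") is None and _lru_get("a") is not None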
77
vllm/v1/structured_output/grammar.py
Normal file
77
vllm/v1/structured_output/grammar.py
Normal file
@ -0,0 +1,77 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StructuredOutputOptions(enum.Enum):
|
||||
JSON = enum.auto()
|
||||
JSON_OBJECT = enum.auto()
|
||||
REGEX = enum.auto()
|
||||
GRAMMAR = enum.auto()
|
||||
CHOICE = enum.auto()
|
||||
|
||||
|
||||
StructuredOutputKey = tuple[StructuredOutputOptions, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Grammar:
|
||||
# NOTE: This is intended to be a generic enough class
# to support different backends in the future.
# For now, just xgrammar.
|
||||
#
|
||||
# TODO: support max_rollback_tokens
|
||||
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
|
||||
# for jump-forward decoding
|
||||
|
||||
vocab_size: int
|
||||
matcher: xgr.GrammarMatcher = field(hash=False)
|
||||
ctx: xgr.CompiledGrammar = field(hash=False)
|
||||
num_processed_tokens: int = field(default_factory=lambda: 0,
|
||||
repr=False,
|
||||
hash=False,
|
||||
init=False)
|
||||
|
||||
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||
"""Accepts a list of tokens and advances the FSM.
|
||||
|
||||
Returns True if the FSM was advanced successfully.
|
||||
Returns False if the FSM failed to advance.
|
||||
"""
|
||||
for token in tokens:
|
||||
if not self.matcher.accept_token(token):
|
||||
logger.error(
|
||||
"Failed to advance FSM for request %s "
|
||||
"for tokens %s. Please file an issue.", request_id, token)
|
||||
return False
|
||||
self.num_processed_tokens += 1
|
||||
return True
|
||||
|
||||
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
|
||||
return self.matcher.fill_next_token_bitmask(bitmask, idx)
|
||||
|
||||
def reset(self):
|
||||
self.num_processed_tokens = 0
|
||||
self.matcher.reset()
|
||||
|
||||
def __copy__(self):
|
||||
return Grammar(
|
||||
matcher=xgr.GrammarMatcher(self.ctx),
|
||||
vocab_size=self.vocab_size,
|
||||
ctx=self.ctx,
|
||||
)
|
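# --- Illustrative sketch (not part of the diff) ---
# __copy__ above gives every copy a fresh GrammarMatcher while sharing the
# compiled ctx, so per-request matcher state is never shared. A self-contained
# analogue of that pattern (MiniMatcher/MiniGrammar are hypothetical stand-ins,
# not xgrammar types):
import copy
from dataclasses import dataclass, field

class MiniMatcher:
    def __init__(self, ctx: str) -> None:
        self.ctx = ctx
        self.accepted: list[int] = []     # per-copy mutable state

@dataclass
class MiniGrammar:
    ctx: str                              # expensive, shared compilation result
    matcher: MiniMatcher = field(init=False)

    def __post_init__(self) -> None:
        self.matcher = MiniMatcher(self.ctx)

    def __copy__(self) -> "MiniGrammar":
        return MiniGrammar(ctx=self.ctx)  # new matcher, same ctx

base = MiniGrammar(ctx="compiled-grammar")
clone = copy.copy(base)
clone.matcher.accepted.append(7)
assert base.matcher.accepted == []        # matcher state is not shared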
71
vllm/v1/structured_output/request.py
Normal file
71
vllm/v1/structured_output/request.py
Normal file
@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import json
|
||||
from concurrent.futures import Future
|
||||
from concurrent.futures._base import TimeoutError
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
|
||||
StructuredOutputOptions)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StructuredOutputRequest:
|
||||
|
||||
sampling_params: SamplingParams
|
||||
_grammar: Optional[Union[Future[Grammar], Grammar]] = None
|
||||
|
||||
def _check_grammar_completion(self) -> bool:
|
||||
# NOTE: We have to lazy import to gate circular imports
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
if isinstance(self._grammar, Future):
|
||||
try:
|
||||
# We will check whether the future is ready within 100 us
|
||||
self._grammar = self._grammar.result(timeout=0.0001)
|
||||
self.status = RequestStatus.WAITING
|
||||
except TimeoutError:
|
||||
return False
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_grammar_ready(self) -> bool:
|
||||
return self._check_grammar_completion()
|
||||
|
||||
@property
|
||||
def grammar(self) -> Optional[Grammar]:
|
||||
completed = self._check_grammar_completion()
|
||||
return cast(Optional[Grammar], self._grammar) if completed else None
|
||||
|
||||
@grammar.setter
|
||||
def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None:
|
||||
self._grammar = grammar
|
||||
|
||||
@functools.cached_property
|
||||
def structured_output_key(self) -> StructuredOutputKey:
|
||||
params = self.sampling_params.guided_decoding
|
||||
assert params is not None, "params can't be None."
|
||||
if params.json is not None:
|
||||
if not isinstance(params.json, str):
|
||||
json_str = json.dumps(params.json)
|
||||
else:
|
||||
json_str = params.json
|
||||
return (StructuredOutputOptions.JSON, json_str)
|
||||
elif params.json_object:
|
||||
return (StructuredOutputOptions.JSON_OBJECT, "")
|
||||
elif params.regex is not None:
|
||||
return (StructuredOutputOptions.REGEX, params.regex)
|
||||
elif params.choice is not None:
|
||||
if not isinstance(params.choice, str):
|
||||
json_str = json.dumps(params.choice)
|
||||
else:
|
||||
json_str = params.choice
|
||||
return (StructuredOutputOptions.CHOICE, json_str)
|
||||
elif params.grammar is not None:
|
||||
return (StructuredOutputOptions.GRAMMAR, params.grammar)
|
||||
else:
|
||||
raise ValueError("No valid structured output parameter found")
|
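# --- Illustrative sketch (not part of the diff) ---
# The grammar property above resolves a concurrent.futures.Future with a very
# small timeout, so callers never block on a grammar that is still compiling.
# Self-contained demonstration of that polling pattern (slow_compile is a
# made-up stand-in for grammar compilation):
import time
from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError

def slow_compile() -> str:
    time.sleep(0.05)
    return "compiled-grammar"

with ThreadPoolExecutor(max_workers=1) as pool:
    fut: Future[str] = pool.submit(slow_compile)
    try:
        grammar = fut.result(timeout=0.0001)    # poll without really waiting
    except TimeoutError:
        grammar = None                          # treated as "still compiling"
    assert grammar is None or grammar == "compiled-grammar"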
295
vllm/v1/structured_output/utils.py
Normal file
295
vllm/v1/structured_output/utils.py
Normal file
@ -0,0 +1,295 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
|
||||
def has_xgrammar_unsupported_json_features(schema: dict[str, Any]) -> bool:
|
||||
"""Check if JSON schema contains features unsupported by xgrammar."""
|
||||
|
||||
def check_object(obj: dict[str, Any]) -> bool:
|
||||
if not isinstance(obj, dict):
|
||||
return False
|
||||
|
||||
# Check for pattern restrictions
|
||||
if "pattern" in obj:
|
||||
return True
|
||||
|
||||
# Check for enum restrictions
|
||||
if "enum" in obj:
|
||||
return True
|
||||
|
||||
# Check for numeric ranges
|
||||
if obj.get("type") in ("integer", "number") and any(
|
||||
key in obj
|
||||
for key in ("minimum", "maximum", "exclusiveMinimum",
|
||||
"exclusiveMaximum", "multipleOf")):
|
||||
return True
|
||||
|
||||
# Check for array unsupported keywords
|
||||
if obj.get("type") == "array" and any(
|
||||
key in obj
|
||||
for key in ("uniqueItems", "contains", "minContains",
|
||||
"maxContains", "minItems", "maxItems")):
|
||||
return True
|
||||
|
||||
# Unsupported keywords for strings
|
||||
if obj.get("type") == "string" and any(
|
||||
key in obj for key in ("minLength", "maxLength", "format")):
|
||||
return True
|
||||
|
||||
# Unsupported keywords for objects
|
||||
if obj.get("type") == "object" and any(
|
||||
key in obj for key in ("minProperties", "maxProperties",
|
||||
"propertyNames", "patternProperties")):
|
||||
return True
|
||||
|
||||
# Recursively check all nested objects and arrays
|
||||
for value in obj.values():
|
||||
if isinstance(value, dict):
|
||||
if check_object(value):
|
||||
return True
|
||||
elif isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, dict) and check_object(item):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return check_object(schema)
|
||||
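# --- Illustrative usage (not part of the diff) ---
# Relies only on the function defined above; the schemas are made-up examples.
_supported = {"type": "object", "properties": {"name": {"type": "string"}}}
_unsupported = {"type": "string", "pattern": "^[a-z]+$"}   # "pattern" is rejected
assert not has_xgrammar_unsupported_json_features(_supported)
assert has_xgrammar_unsupported_json_features(_unsupported)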
|
||||
|
||||
def grammar_is_likely_lark(grammar_str: str) -> bool:
|
||||
"""
|
||||
Check if grammar appears to use Lark syntax.
|
||||
|
||||
Args:
|
||||
grammar_str: Input grammar string
|
||||
|
||||
Returns:
|
||||
bool: True if grammar appears to be in Lark format, False otherwise
|
||||
|
||||
Examples:
|
||||
>>> grammar_is_likely_lark("rule: 'abc'")
|
||||
True
|
||||
>>> grammar_is_likely_lark("rule ::= 'abc'")
|
||||
False
|
||||
"""
|
||||
if not grammar_str or not isinstance(grammar_str, str):
|
||||
return False
|
||||
|
||||
for line in grammar_str.split('\n'):
|
||||
# Remove both comment styles
|
||||
line = re.sub(r'(#|//).*$', '', line).strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Look for EBNF rule definition
|
||||
if '::=' in line:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def convert_lark_to_ebnf(grammar_str: str) -> str:
|
||||
"""
|
||||
Convert a Lark grammar string to EBNF format.
|
||||
|
||||
EBNF reference:
|
||||
https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||
Lark grammar reference:
|
||||
https://lark-parser.readthedocs.io/en/latest/grammar.html
|
||||
|
||||
Args:
|
||||
grammar_str: Input grammar in Lark format
|
||||
|
||||
Returns:
|
||||
str: Converted grammar in EBNF format
|
||||
|
||||
Examples:
|
||||
>>> print(convert_lark_to_ebnf("rule: 'hello'"))
|
||||
root ::= rule
|
||||
rule ::= "hello"
|
||||
"""
|
||||
if not isinstance(grammar_str, str):
|
||||
raise ValueError(f"Grammar must be a string, got {type(grammar_str)}")
|
||||
if not grammar_str.strip():
|
||||
raise ValueError("Grammar string cannot be empty")
|
||||
|
||||
defined_rules = set()
|
||||
referenced_rules = set()
|
||||
output_lines = []
|
||||
|
||||
def clean_line(line: str) -> str:
|
||||
"""Remove comments and whitespace from line."""
|
||||
return re.sub(r'(#|//).*$', '', line).strip()
|
||||
|
||||
def check_quotes(text: str, rule_name: str, line_num: int) -> None:
|
||||
"""Validate quote matching in text."""
|
||||
if text.count("'") % 2 != 0 or text.count('"') % 2 != 0:
|
||||
raise ValueError(
|
||||
f"Mismatched quotes in {rule_name} on line {line_num}")
|
||||
|
||||
def extract_references(text: str) -> set:
|
||||
"""Extract rule references from text."""
|
||||
# Remove quoted strings and special characters
|
||||
text = re.sub(r'"[^"]*"', '', text)
|
||||
text = re.sub(r'[+*?()|\[\]{}]', ' ', text)
|
||||
return set(re.findall(r'\b[a-zA-Z_][a-zA-Z0-9_]*\b', text))
|
||||
|
||||
# First pass: Find root rule and validate rule definitions
|
||||
lines = [clean_line(line) for line in grammar_str.split('\n')]
|
||||
first_rule = None
|
||||
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
if not line or line.startswith('|'):
|
||||
continue
|
||||
|
||||
if ':' in line:
|
||||
try:
|
||||
name = line.split(':', 1)[0].strip().strip('?')
|
||||
defined_rules.add(name)
|
||||
if first_rule is None:
|
||||
first_rule = name
|
||||
if name == 'start':
|
||||
first_rule = 'start'
|
||||
except IndexError as e:
|
||||
raise ValueError(f"Invalid rule format on line {line_num}. "
|
||||
"Expected 'rule_name: definition'") from e
|
||||
|
||||
if not defined_rules:
|
||||
raise ValueError("No valid rules found in grammar")
|
||||
|
||||
# Add root rule
|
||||
output_lines.append(f"root ::= {first_rule}")
|
||||
|
||||
# Second pass: Process rule definitions and alternatives
|
||||
current_rule = None
|
||||
current_definition = []
|
||||
|
||||
for line_num, line in enumerate(lines, 1):
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
if ':' in line and not line.startswith('|'):
|
||||
# Save previous rule if exists
|
||||
if current_rule:
|
||||
output_lines.append(
|
||||
f"{current_rule} ::= {' | '.join(current_definition)}")
|
||||
|
||||
# Process new rule
|
||||
name, definition = line.split(':', 1)
|
||||
current_rule = name.strip().strip('?')
|
||||
|
||||
check_quotes(definition, f"rule '{current_rule}'", line_num)
|
||||
definition = re.sub(r"'([^']*)'", r'"\1"', definition)
|
||||
referenced_rules.update(extract_references(definition))
|
||||
current_definition = [definition.strip()]
|
||||
|
||||
elif line.startswith('|'):
|
||||
if not current_rule:
|
||||
raise ValueError(f"Alternative '|' on line {line_num} "
|
||||
"without a preceding rule definition")
|
||||
|
||||
alt_def = line[1:].strip()
|
||||
check_quotes(alt_def, f"alternative for rule '{current_rule}'",
|
||||
line_num)
|
||||
alt_def = re.sub(r"'([^']*)'", r'"\1"', alt_def)
|
||||
referenced_rules.update(extract_references(alt_def))
|
||||
current_definition.append(alt_def)
|
||||
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error on line {line_num}: {str(e)}") from e
|
||||
|
||||
# Add final rule if exists
|
||||
if current_rule:
|
||||
output_lines.append(
|
||||
f"{current_rule} ::= {' | '.join(current_definition)}")
|
||||
|
||||
# Validate all rules are defined
|
||||
undefined_rules = referenced_rules - defined_rules - {'root'}
|
||||
if undefined_rules:
|
||||
raise ValueError("Referenced rules are not defined: "
|
||||
f"{', '.join(sorted(undefined_rules))}")
|
||||
|
||||
return '\n'.join(output_lines)
|
||||
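# --- Illustrative usage (not part of the diff) ---
# Relies on grammar_is_likely_lark and convert_lark_to_ebnf defined above; the
# tiny grammar is a made-up example.
_lark_grammar = """
start: greeting
greeting: "hello" | "hi"
"""
assert grammar_is_likely_lark(_lark_grammar)
print(convert_lark_to_ebnf(_lark_grammar))
# Expected output, per the conversion rules above:
#   root ::= start
#   start ::= greeting
#   greeting ::= "hello" | "hi"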
|
||||
|
||||
def choice_as_grammar(choice: list[str]) -> str:
|
||||
|
||||
def escape_ebnf_string(s: str) -> str:
|
||||
"""Escape special characters in a EBNF string."""
|
||||
# Escape double quotes and backslashes
|
||||
return re.sub(r'(["\\])', r'\\\1', s)
|
||||
|
||||
escaped_choices = (escape_ebnf_string(c) for c in choice)
|
||||
grammar = ('root ::= ' + ' | '.join(f'"{c}"' for c in escaped_choices))
|
||||
return grammar
|
||||
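# --- Illustrative usage (not part of the diff) ---
# choice_as_grammar (above) turns a list of literal choices into a one-rule
# EBNF grammar, escaping quotes and backslashes. Made-up example:
_choices = ["yes", "no", 'say "maybe"']
print(choice_as_grammar(_choices))
# Expected output: root ::= "yes" | "no" | "say \"maybe\""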
|
||||
|
||||
def validate_structured_output_request(
|
||||
sampling_params: SamplingParams) -> None:
|
||||
"""Validate that the request is supported by structured output.
|
||||
|
||||
Raises ValueError if the request is not supported.
|
||||
"""
|
||||
if sampling_params.guided_decoding is None:
|
||||
return
|
||||
|
||||
gd_params = sampling_params.guided_decoding
|
||||
|
||||
if gd_params.regex:
|
||||
raise ValueError("Regex structured output is not supported.")
|
||||
|
||||
if gd_params.choice:
|
||||
choice_grammar = choice_as_grammar(gd_params.choice)
|
||||
try:
|
||||
xgr.Grammar.from_ebnf(choice_grammar)
|
||||
except Exception as err:
|
||||
raise ValueError("Failed to transform choices into a grammar: "
|
||||
"{err}") from err
|
||||
gd_params.choice = None
|
||||
gd_params.grammar = choice_grammar
|
||||
return
|
||||
|
||||
if gd_params.json:
|
||||
if isinstance(gd_params.json, str):
|
||||
try:
|
||||
schema = json.loads(gd_params.json)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError("Invalid JSON grammar specification.") from e
|
||||
else:
|
||||
schema = gd_params.json
|
||||
|
||||
if has_xgrammar_unsupported_json_features(schema):
|
||||
raise ValueError("The provided JSON schema contains features not "
|
||||
"supported by xgrammar.")
|
||||
return
|
||||
|
||||
if gd_params.grammar:
|
||||
if grammar_is_likely_lark(gd_params.grammar):
|
||||
# xgrammar supports EBNF grammars only
|
||||
try:
|
||||
gd_params.grammar = convert_lark_to_ebnf(gd_params.grammar)
|
||||
except ValueError as e:
|
||||
raise ValueError(
|
||||
"Failed to convert the grammar from Lark to EBNF. ") from e
|
||||
|
||||
# Test parsing EBNF grammar, possibly already converted from Lark
|
||||
try:
|
||||
# parse the grammar, but we aren't compiling it.
|
||||
xgr.Grammar.from_ebnf(gd_params.grammar)
|
||||
except Exception as e:
|
||||
raise ValueError("Invalid grammar specification.") from e
|
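# --- Illustrative usage (not part of the diff) ---
# A sketch of how a request typically reaches this validator, using
# GuidedDecodingParams/SamplingParams from vllm.sampling_params (the same
# entrypoints used elsewhere in this PR). The exact constructor arguments are
# an assumption; treat this as a sketch rather than a canonical snippet.
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

params = SamplingParams(
    max_tokens=128,
    guided_decoding=GuidedDecodingParams(json={"type": "object"}),
)
validate_structured_output_request(params)   # raises ValueError if unsupported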
@ -25,7 +25,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LayerBlockType, cdiv, is_pin_memory_available)
|
||||
LayerBlockType, LazyLoader, cdiv,
|
||||
is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
|
||||
@ -40,7 +41,11 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
|
||||
from vllm.v1.core.scheduler_output import SchedulerOutput
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -860,6 +865,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def apply_grammar_bitmask(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
logits: torch.Tensor,
|
||||
):
|
||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||
# so we receive it in that format.
|
||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||
if grammar_bitmask is None:
|
||||
return
|
||||
|
||||
# We receive the structured output bitmask from the scheduler, but the
|
||||
# indices of the requests in the batch may not match the indices of
|
||||
# the bitmask since the scheduler doesn't know how the gpu runner is
|
||||
# ordering the requests in the batch. We need to sort the bitmask to
|
||||
# match the order of the requests used here.
|
||||
struct_out_req_batch_indices: dict[str, int] = {}
|
||||
indices_match = True
|
||||
for req_id in self.input_batch.req_ids:
|
||||
mask_index = scheduler_output.structured_output_request_ids.get(
|
||||
req_id)
|
||||
if mask_index is None:
|
||||
# not a structured output request
|
||||
continue
|
||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||
if batch_index != mask_index:
|
||||
indices_match = False
|
||||
struct_out_req_batch_indices[req_id] = batch_index
|
||||
|
||||
if not indices_match:
|
||||
# Sort the bitmask to match the order of the requests
|
||||
sorted_bitmask = np.zeros_like(grammar_bitmask)
|
||||
for req_id, batch_index in struct_out_req_batch_indices.items():
|
||||
orig_index = scheduler_output.structured_output_request_ids[
|
||||
req_id]
|
||||
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
|
||||
grammar_bitmask = sorted_bitmask
|
||||
|
||||
grammar_bitmask = torch.from_numpy(grammar_bitmask)
|
||||
|
||||
# TODO: compatibility with spec decode
|
||||
xgr.apply_token_bitmask_inplace(
|
||||
logits,
|
||||
grammar_bitmask.to(self.device, non_blocking=True),
|
||||
indices=list(struct_out_req_batch_indices.values()),
|
||||
)
|
||||
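# --- Illustrative sketch (not part of the diff) ---
# How the reordering in apply_grammar_bitmask works when the scheduler's batch
# indices differ from the GPU runner's ordering, shown with plain NumPy and
# made-up request ids:
import numpy as np

grammar_bitmask = np.array([[0b1010], [0b0101]], dtype=np.int32)  # scheduler order
scheduler_indices = {"req-a": 0, "req-b": 1}    # row index of each request above
runner_indices = {"req-a": 1, "req-b": 0}       # runner's batch order differs

sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, batch_index in runner_indices.items():
    sorted_bitmask[batch_index] = grammar_bitmask[scheduler_indices[req_id]]
# sorted_bitmask rows now line up with the runner's batch order
assert sorted_bitmask[1, 0] == 0b1010 and sorted_bitmask[0, 0] == 0b0101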
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@ -945,6 +997,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
if not self.use_spec_decode:
|
||||
|