diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py
index 3d43e045..dccef9d9 100644
--- a/benchmarks/benchmark_serving_structured_output.py
+++ b/benchmarks/benchmark_serving_structured_output.py
@@ -24,11 +24,13 @@ On the client side, run:
 """
 import argparse
 import asyncio
+import copy
 import dataclasses
 import json
 import os
 import random
 import time
+import uuid
 import warnings
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass
@@ -109,24 +111,47 @@ class SampleRequest:

 def sample_requests(tokenizer: PreTrainedTokenizerBase,
                     args: argparse.Namespace) -> list[SampleRequest]:
-    if args.dataset == 'json':
+    if args.dataset == 'json' or args.dataset == 'json-unique':
         if args.json_schema_path is None:
             dir_path = os.path.dirname(os.path.realpath(__file__))
             args.json_schema_path = os.path.join(dir_path,
                                                  "structured_schemas",
                                                  "structured_schema_1.json")
+
         with open(args.json_schema_path) as f:
             schema = json.load(f)
-        prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}"  # noqa: E501
-        input_len = len(tokenizer(prompt).input_ids)
-        print(f"Input length of the prompt: {input_len} tokens")
+
+        if args.dataset == 'json-unique':
+            # Give every request its own schema so structured-output
+            # backends cannot reuse a cached compiled grammar.
+            json_schemas = [
+                copy.deepcopy(schema) for _ in range(args.num_prompts)
+            ]
+            for i in range(len(json_schemas)):
+                json_schemas[i]["properties"][
+                    f"__optional_field_{uuid.uuid4()}"] = {
+                        "type":
+                        "string",
+                        "description":
+                        "A unique optional field to avoid cached schemas"
+                    }
+        else:
+            json_schemas = [schema] * args.num_prompts
+
+        def gen_prompt(index: int):
+            schema = json_schemas[index % len(json_schemas)]
+            return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}"  # noqa: E501
+
+        def get_schema(index: int):
+            return json_schemas[index % len(json_schemas)]
+
         requests = [
-            SampleRequest(prompt=prompt,
-                          prompt_len=input_len,
+            SampleRequest(prompt=gen_prompt(i),
+                          prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
                           expected_output_len=args.output_len,
-                          schema=schema,
+                          schema=get_schema(i),
                           structure_type=args.structure_type)
-            for _ in range(args.num_prompts)
+            for i in range(args.num_prompts)
         ]

     elif args.dataset == "grammar":
@@ -821,10 +842,12 @@ if __name__ == "__main__":
         default="/v1/completions",
         help="API endpoint.",
     )
-    parser.add_argument(
-        "--dataset",
-        default='json',
-        choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
+    parser.add_argument("--dataset",
+                        default='json',
+                        choices=[
+                            'json', 'json-unique', 'grammar', 'regex',
+                            'choice', 'xgrammar_bench'
+                        ])
     parser.add_argument("--json_schema_path",
                         type=str,
                         default=None,
@@ -966,11 +989,12 @@ if __name__ == "__main__":
                         type=float,
                         default=1.0,
                         help="Ratio of Structured Outputs requests")
-    parser.add_argument("--structured-output-backend",
-                        type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
-                        help="Backend to use for structured outputs")
+    parser.add_argument(
+        "--structured-output-backend",
+        type=str,
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "json-unique"],
+        default="xgrammar",
+        help="Backend to use for structured outputs")

     args = parser.parse_args()
     main(args)
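For context, here is a minimal standalone sketch of the cache-busting trick the new `json-unique` dataset relies on: each copy of the base schema gains one extra optional property with a UUID-derived name, so no two schemas serialize identically and backends that cache compiled grammars by schema text get no hits. The base schema and `num_prompts` below are illustrative stand-ins, not taken from the benchmark; only the `__optional_field_` injection mirrors the diff.

```python
import copy
import json
import uuid

# Illustrative stand-in for the schema loaded from --json_schema_path.
base_schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name"],
}

num_prompts = 3  # the benchmark uses args.num_prompts here

# One deep copy per request, each with a unique optional field.
schemas = [copy.deepcopy(base_schema) for _ in range(num_prompts)]
for s in schemas:
    s["properties"][f"__optional_field_{uuid.uuid4()}"] = {
        "type": "string",
        "description": "A unique optional field to avoid cached schemas",
    }

# Every serialized schema is distinct, so a schema-keyed cache misses.
assert len({json.dumps(s, sort_keys=True) for s in schemas}) == num_prompts
```

Because the injected property is optional (it is never added to `required`), any output valid under the original schema stays valid under the modified one; only the schema text changes, which is exactly what forces a fresh grammar compilation per request.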
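Assuming the script's existing serving-benchmark flags (`--model`, `--num-prompts`, and so on are untouched by this diff), the new path should be exercised with something like `python benchmarks/benchmark_serving_structured_output.py --model <served-model> --dataset json-unique --num-prompts 1000`, where the placeholder model name and prompt count are illustrative. Comparing a `json` run against a `json-unique` run at the same request rate then isolates how much of the measured latency comes from per-schema grammar compilation rather than decoding.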