[benchmarks] Add option to use unique jsonschema for each request (#14457)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
8d5aa466fb
commit
9085aabd62
@ -24,11 +24,13 @@ On the client side, run:
|
|||||||
"""
|
"""
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -109,24 +111,43 @@ class SampleRequest:
|
|||||||
|
|
||||||
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
def sample_requests(tokenizer: PreTrainedTokenizerBase,
|
||||||
args: argparse.Namespace) -> list[SampleRequest]:
|
args: argparse.Namespace) -> list[SampleRequest]:
|
||||||
if args.dataset == 'json':
|
if args.dataset == 'json' or args.dataset == 'json-unique':
|
||||||
if args.json_schema_path is None:
|
if args.json_schema_path is None:
|
||||||
dir_path = os.path.dirname(os.path.realpath(__file__))
|
dir_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
args.json_schema_path = os.path.join(dir_path,
|
args.json_schema_path = os.path.join(dir_path,
|
||||||
"structured_schemas",
|
"structured_schemas",
|
||||||
"structured_schema_1.json")
|
"structured_schema_1.json")
|
||||||
|
json_schemas = []
|
||||||
with open(args.json_schema_path) as f:
|
with open(args.json_schema_path) as f:
|
||||||
schema = json.load(f)
|
schema = json.load(f)
|
||||||
prompt = f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
|
||||||
input_len = len(tokenizer(prompt).input_ids)
|
if args.dataset == 'json-unique':
|
||||||
print(f"Input length of the prompt: {input_len} tokens")
|
json_schemas = [
|
||||||
|
copy.deepcopy(schema) for _ in range(args.num_prompts)
|
||||||
|
]
|
||||||
|
for i in range(len(json_schemas)):
|
||||||
|
json_schemas[i]["properties"][
|
||||||
|
f"__optional_field_{uuid.uuid4()}"] = {
|
||||||
|
"type":
|
||||||
|
"string",
|
||||||
|
"description":
|
||||||
|
"An unique optional field to avoid cached schemas"
|
||||||
|
}
|
||||||
|
|
||||||
|
def gen_prompt(index: int):
|
||||||
|
schema = json_schemas[index % len(json_schemas)]
|
||||||
|
return f"Generate an example of a user profile given the following schema: {json.dumps(schema)}" # noqa: E501
|
||||||
|
|
||||||
|
def get_schema(index: int):
|
||||||
|
return json_schemas[index % len(json_schemas)]
|
||||||
|
|
||||||
requests = [
|
requests = [
|
||||||
SampleRequest(prompt=prompt,
|
SampleRequest(prompt=gen_prompt(i),
|
||||||
prompt_len=input_len,
|
prompt_len=len(tokenizer(gen_prompt(i)).input_ids),
|
||||||
expected_output_len=args.output_len,
|
expected_output_len=args.output_len,
|
||||||
schema=schema,
|
schema=get_schema(i),
|
||||||
structure_type=args.structure_type)
|
structure_type=args.structure_type)
|
||||||
for _ in range(args.num_prompts)
|
for i in range(args.num_prompts)
|
||||||
]
|
]
|
||||||
|
|
||||||
elif args.dataset == "grammar":
|
elif args.dataset == "grammar":
|
||||||
@ -821,10 +842,12 @@ if __name__ == "__main__":
|
|||||||
default="/v1/completions",
|
default="/v1/completions",
|
||||||
help="API endpoint.",
|
help="API endpoint.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--dataset",
|
||||||
"--dataset",
|
|
||||||
default='json',
|
default='json',
|
||||||
choices=['json', 'grammar', 'regex', 'choice', 'xgrammar_bench'])
|
choices=[
|
||||||
|
'json', 'json-unique', 'grammar', 'regex',
|
||||||
|
'choice', 'xgrammar_bench'
|
||||||
|
])
|
||||||
parser.add_argument("--json_schema_path",
|
parser.add_argument("--json_schema_path",
|
||||||
type=str,
|
type=str,
|
||||||
default=None,
|
default=None,
|
||||||
@ -966,9 +989,10 @@ if __name__ == "__main__":
|
|||||||
type=float,
|
type=float,
|
||||||
default=1.0,
|
default=1.0,
|
||||||
help="Ratio of Structured Outputs requests")
|
help="Ratio of Structured Outputs requests")
|
||||||
parser.add_argument("--structured-output-backend",
|
parser.add_argument(
|
||||||
|
"--structured-output-backend",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["outlines", "lm-format-enforcer", "xgrammar"],
|
choices=["outlines", "lm-format-enforcer", "xgrammar", "json-unique"],
|
||||||
default="xgrammar",
|
default="xgrammar",
|
||||||
help="Backend to use for structured outputs")
|
help="Backend to use for structured outputs")
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user