# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import re
from enum import Enum
from typing import Any

import jsonschema
import pytest
from pydantic import BaseModel

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
MODELS_TO_TEST = [
    "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
]


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
    model_name: str,
):
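    """Each output must parse as JSON and validate against the sample
    schema."""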
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
    outputs = llm.generate(prompts=[
        f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_disable_any_whitespace(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
    model_name: str,
):
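    """With xgrammar's disable-any-whitespace option, outputs must still
    satisfy the schema and must not contain newlines."""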
    if guided_decoding_backend != "xgrammar":
        pytest.skip("disable-any-whitespace is only supported for xgrammar.")
    guided_decoding_backend = 'xgrammar:disable-any-whitespace'

    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=sample_json_schema))
    outputs = llm.generate(prompts=[
        f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        assert "\n" not in generated_text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_object(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    model_name: str,
):
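    """With json_object=True, every sampled completion must parse as valid
    JSON (see the in-code TODO on xgrammar also permitting lists)."""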
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=100,
        n=2,
        guided_decoding=GuidedDecodingParams(json_object=True))

    outputs = llm.generate(
        prompts=("Generate a JSON object with curly braces for a person with "
                 "name and age fields for John Smith who is 31 years old."),
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        for i in range(2):
            generated_text = output.outputs[i].text
            print(generated_text)
            assert generated_text is not None

            # Parse to verify it is valid JSON
            parsed_json = json.loads(generated_text)
            allowed_types: tuple[type, ...] = (dict, )
            if guided_decoding_backend == "xgrammar":
                # TODO - we are currently too permissive with xgrammar and
                # allow any valid json (typically comes back as a list or
                # object). We can fix this by specifying a jsonschema of
                # {"type": "object"}, but we need this fix in a release
                # first: https://github.com/mlc-ai/xgrammar/pull/264
                allowed_types = (dict, list)
            assert isinstance(parsed_json, allowed_types)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1 + ["auto"])
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_unsupported_schema(
    monkeypatch: pytest.MonkeyPatch,
    unsupported_json_schema: dict[str, Any],
    guided_decoding_backend: str,
    model_name: str,
):
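    """xgrammar must raise ValueError for schemas it cannot compile, while
    the "guidance" and "auto" backends should still produce valid JSON."""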
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
    if guided_decoding_backend == "xgrammar":
        with pytest.raises(ValueError,
                           match="The provided JSON schema contains features "
                           "not supported by xgrammar."):
            llm.generate(prompts=[
                f"Give an example JSON for an employee profile "
                f"that fits this schema: {unsupported_json_schema}"
            ] * 2,
                         sampling_params=sampling_params,
                         use_tqdm=True)
    else:
        # This should work for both "guidance" and "auto".
        outputs = llm.generate(
            prompts=("Give an example JSON object for a grade "
                     "that fits this schema: "
                     f"{unsupported_json_schema}"),
            sampling_params=sampling_params,
            use_tqdm=True)
        assert outputs is not None
        for output in outputs:
            assert output is not None
            assert isinstance(output, RequestOutput)
            generated_text = output.outputs[0].text
            assert generated_text is not None
            print(generated_text)

            # Parse to verify it is valid JSON
            parsed_json = json.loads(generated_text)
            assert isinstance(parsed_json, dict)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_ebnf(
    monkeypatch: pytest.MonkeyPatch,
    sample_sql_ebnf: str,
    guided_decoding_backend: str,
    model_name: str,
):
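    """Text constrained by an EBNF SQL grammar must match the expected
    statement once whitespace is stripped."""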
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
    outputs = llm.generate(
        prompts=("Generate a sql statement that selects col_1 from "
                 "table_1 where it is equal to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_lark(
    monkeypatch: pytest.MonkeyPatch,
    sample_sql_lark: str,
    guided_decoding_backend: str,
    model_name: str,
):
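    """Text constrained by a Lark SQL grammar must parse under that grammar
    and match the expected statement once whitespace is stripped."""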
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
    outputs = llm.generate(
        prompts=("Generate a sql statement that selects col_1 from "
                 "table_1 where it is equal to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        parser = Lark(sample_sql_lark)
        parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_grammar_ebnf_invalid(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    model_name: str,
):
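    """Passing a string that is not a grammar must raise ValueError."""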
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
    with pytest.raises(ValueError, match="Failed to convert the grammar "):
        llm.generate(
            prompts=("Generate a sql statement that selects col_1 from "
                     "table_1 where it is equal to 1"),
            sampling_params=sampling_params,
            use_tqdm=True,
        )


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_regex(
    monkeypatch: pytest.MonkeyPatch,
    sample_regex: str,
    guided_decoding_backend: str,
    model_name: str,
):
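    """Generated text must fully match the guiding regex."""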
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(regex=sample_regex))
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_choice_completion(
    monkeypatch: pytest.MonkeyPatch,
    sample_guided_choice: str,
    guided_decoding_backend: str,
    model_name: str,
):
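    """Generated text must be exactly one of the allowed choices."""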
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    sampling_params = SamplingParams(
        temperature=0.8,
        top_p=0.95,
        guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True)
    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
class CarType(str, Enum):
    sedan = "sedan"
    suv = "SUV"
    truck = "Truck"
    coupe = "Coupe"


class CarDescription(BaseModel):
    brand: str
    model: str
    car_type: CarType


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
def test_guided_json_completion_with_enum(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    model_name: str,
):
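    """JSON output must validate against a Pydantic-derived schema that
    contains enum fields."""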
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name,
              max_model_len=1024,
              guided_decoding_backend=guided_decoding_backend)
    json_schema = CarDescription.model_json_schema()
    sampling_params = SamplingParams(
        temperature=1.0,
        max_tokens=1000,
        guided_decoding=GuidedDecodingParams(json=json_schema))
    outputs = llm.generate(
        prompts="Generate a JSON with the brand, model and car_type of "
        "the most iconic car from the 90's",
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=json_schema)