# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import json
import re
from typing import Any

import jsonschema
import pytest

from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]


@pytest.fixture(params=[
    "Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
])
def model_name(request: pytest.FixtureRequest) -> str:
    # Parametrize over the models so every test runs once per model.
    return request.param


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_completion(
    monkeypatch: pytest.MonkeyPatch,
    sample_json_schema: dict[str, Any],
    guided_decoding_backend: str,
    model_name: str,
):
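    """Generate text constrained by a JSON schema and validate each
    completion against that schema."""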
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         json=sample_json_schema,
                                         backend=guided_decoding_backend))
    outputs = llm.generate(prompts=[
        "Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}"
    ] * 2,
                           sampling_params=sampling_params,
                           use_tqdm=True)

    assert outputs is not None

    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
        output_json = json.loads(generated_text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_object(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    model_name: str,
):
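    """Constrain output to arbitrary JSON (no schema) and check that all
    n=2 completions parse as JSON objects."""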
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=100,
                                     n=2,
                                     guided_decoding=GuidedDecodingParams(
                                         json_object=True,
                                         backend=guided_decoding_backend))

    outputs = llm.generate(
        prompts=("Generate a JSON object with curly braces for a person with "
                 "name and age fields for John Smith who is 31 years old."),
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)

        for i in range(2):
            generated_text = output.outputs[i].text
            print(generated_text)
            assert generated_text is not None

            # Parse to verify it is valid JSON
            parsed_json = json.loads(generated_text)
            assert isinstance(parsed_json, dict)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_json_unsupported_schema(
    monkeypatch: pytest.MonkeyPatch,
    unsupported_json_schema: dict[str, Any],
    guided_decoding_backend: str,
    model_name: str,
):
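    """A schema using features unsupported by xgrammar should surface a
    ValueError from generate()."""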
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=1.0,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         json=unsupported_json_schema,
                                         backend=guided_decoding_backend))
    with pytest.raises(ValueError,
                       match="The provided JSON schema contains features "
                       "not supported by xgrammar."):
        llm.generate(prompts=[
            "Give an example JSON for an employee profile "
            f"that fits this schema: {unsupported_json_schema}"
        ] * 2,
                     sampling_params=sampling_params,
                     use_tqdm=True)


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf(
    monkeypatch: pytest.MonkeyPatch,
    sample_sql_ebnf: str,
    guided_decoding_backend: str,
    model_name: str,
):
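    """Constrain output with an SQL grammar in EBNF form and compare against
    the expected statement, ignoring whitespace removed by the grammar."""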
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         grammar=sample_sql_ebnf,
                                         backend=guided_decoding_backend))
    outputs = llm.generate(
        prompts=("Generate a sql statement that selects col_1 from "
                 "table_1 where it is equal to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_lark(
    monkeypatch: pytest.MonkeyPatch,
    sample_sql_lark: str,
    guided_decoding_backend: str,
    model_name: str,
):
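    """Constrain output with an SQL grammar in Lark form, check the result
    parses under that grammar, and compare against the expected statement."""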
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         grammar=sample_sql_lark,
                                         backend=guided_decoding_backend))
    outputs = llm.generate(
        prompts=("Generate a sql statement that selects col_1 from "
                 "table_1 where it is equal to 1"),
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt

        generated_text = output.outputs[0].text
        assert generated_text is not None

        # use Lark to parse the output, and make sure it's a valid parse tree
        from lark import Lark
        parser = Lark(sample_sql_lark)
        parser.parse(generated_text)

        # remove spaces for comparison b/c we removed them in the grammar
        ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(
            " ", "")

        assert generated_text.strip() == ground_truth

        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_grammar_ebnf_invalid(
    monkeypatch: pytest.MonkeyPatch,
    guided_decoding_backend: str,
    model_name: str,
):
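    """An invalid grammar string should surface a ValueError when generation
    is attempted."""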
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=1000,
                                     guided_decoding=GuidedDecodingParams(
                                         grammar="not a grammar",
                                         backend=guided_decoding_backend))
    with pytest.raises(ValueError,
                       match="Failed to convert the grammar "
                       "from Lark to EBNF."):
        llm.generate(
            prompts=("Generate a sql statement that selects col_1 from "
                     "table_1 where it is equal to 1"),
            sampling_params=sampling_params,
            use_tqdm=True,
        )


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_regex(
    monkeypatch: pytest.MonkeyPatch,
    sample_regex: str,
    guided_decoding_backend: str,
    model_name: str,
):
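    """Constrain output with a regex and verify each completion fully
    matches it."""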
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_decoding=GuidedDecodingParams(
                                         regex=sample_regex,
                                         backend=guided_decoding_backend))
    outputs = llm.generate(
        prompts=[
            f"Give an example IPv4 address with this regex: {sample_regex}"
        ] * 2,
        sampling_params=sampling_params,
        use_tqdm=True,
    )

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert re.fullmatch(sample_regex, generated_text) is not None
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


@pytest.mark.skip_global_cleanup
@pytest.mark.parametrize("guided_decoding_backend",
                         GUIDED_DECODING_BACKENDS_V1)
def test_guided_choice_completion(
    monkeypatch: pytest.MonkeyPatch,
    sample_guided_choice: list[str],
    guided_decoding_backend: str,
    model_name: str,
):
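    """Constrain output to a fixed set of choices and verify each completion
    is one of them."""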
    monkeypatch.setenv("VLLM_USE_V1", "1")
    llm = LLM(model=model_name, max_model_len=1024)
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     guided_decoding=GuidedDecodingParams(
                                         choice=sample_guided_choice,
                                         backend=guided_decoding_backend))
    outputs = llm.generate(
        prompts="The best language for type-safe systems programming is ",
        sampling_params=sampling_params,
        use_tqdm=True)

    assert outputs is not None
    for output in outputs:
        assert output is not None
        assert isinstance(output, RequestOutput)
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(generated_text)
        assert generated_text is not None
        assert generated_text in sample_guided_choice
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")