[V1] guidance backend for structured output + auto
fallback mode (#14779)
Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Loc Huynh <jc1da.3011@gmail.com> Co-authored-by: Michal Moskal <michal@moskal.me>
This commit is contained in:
parent
10b34e36b9
commit
a09ad90a72
@ -18,7 +18,7 @@ pillow # Required for image processing
|
|||||||
prometheus-fastapi-instrumentator >= 7.0.0
|
prometheus-fastapi-instrumentator >= 7.0.0
|
||||||
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
tiktoken >= 0.6.0 # Required for DBRX tokenizer
|
||||||
lm-format-enforcer >= 0.10.11, < 0.11
|
lm-format-enforcer >= 0.10.11, < 0.11
|
||||||
llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
|
||||||
outlines == 0.1.11
|
outlines == 0.1.11
|
||||||
lark == 1.2.2
|
lark == 1.2.2
|
||||||
xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
|
xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
|
||||||
|
@ -13,7 +13,7 @@ from vllm.entrypoints.llm import LLM
|
|||||||
from vllm.outputs import RequestOutput
|
from vllm.outputs import RequestOutput
|
||||||
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
|
||||||
|
|
||||||
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar"]
|
GUIDED_DECODING_BACKENDS_V1 = ["xgrammar", "guidance"]
|
||||||
MODELS_TO_TEST = [
|
MODELS_TO_TEST = [
|
||||||
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
|
"Qwen/Qwen2.5-1.5B-Instruct", "mistralai/Ministral-8B-Instruct-2410"
|
||||||
]
|
]
|
||||||
@ -30,12 +30,13 @@ def test_guided_json_completion(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=1.0,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(json=sample_json_schema))
|
||||||
json=sample_json_schema,
|
|
||||||
backend=guided_decoding_backend))
|
|
||||||
outputs = llm.generate(prompts=[
|
outputs = llm.generate(prompts=[
|
||||||
f"Give an example JSON for an employee profile "
|
f"Give an example JSON for an employee profile "
|
||||||
f"that fits this schema: {sample_json_schema}"
|
f"that fits this schema: {sample_json_schema}"
|
||||||
@ -111,13 +112,14 @@ def test_guided_json_object(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=1.0,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
max_tokens=100,
|
max_tokens=100,
|
||||||
n=2,
|
n=2,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(json_object=True))
|
||||||
json_object=True,
|
|
||||||
backend=guided_decoding_backend))
|
|
||||||
|
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts=("Generate a JSON object with curly braces for a person with "
|
prompts=("Generate a JSON object with curly braces for a person with "
|
||||||
@ -137,12 +139,20 @@ def test_guided_json_object(
|
|||||||
|
|
||||||
# Parse to verify it is valid JSON
|
# Parse to verify it is valid JSON
|
||||||
parsed_json = json.loads(generated_text)
|
parsed_json = json.loads(generated_text)
|
||||||
assert isinstance(parsed_json, dict)
|
allowed_types: tuple[type, ...] = (dict, )
|
||||||
|
if guided_decoding_backend == "xgrammar":
|
||||||
|
# TODO - we are currently too permissive with xgrammar and
|
||||||
|
# allow # any valid json (typically comes back as a list or
|
||||||
|
# object). We can fix this by specifying a jsonschema of
|
||||||
|
# {"type": "object"}, # but we need this fix in a release
|
||||||
|
# first: https://github.com/mlc-ai/xgrammar/pull/264
|
||||||
|
allowed_types = (dict, list)
|
||||||
|
assert isinstance(parsed_json, allowed_types)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@pytest.mark.parametrize("guided_decoding_backend",
|
@pytest.mark.parametrize("guided_decoding_backend",
|
||||||
GUIDED_DECODING_BACKENDS_V1)
|
GUIDED_DECODING_BACKENDS_V1 + ["auto"])
|
||||||
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
|
@pytest.mark.parametrize("model_name", MODELS_TO_TEST)
|
||||||
def test_guided_json_unsupported_schema(
|
def test_guided_json_unsupported_schema(
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
@ -151,12 +161,14 @@ def test_guided_json_unsupported_schema(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=1.0,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=1.0,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(json=unsupported_json_schema))
|
||||||
json=unsupported_json_schema,
|
if guided_decoding_backend == "xgrammar":
|
||||||
backend=guided_decoding_backend))
|
|
||||||
with pytest.raises(ValueError,
|
with pytest.raises(ValueError,
|
||||||
match="The provided JSON schema contains features "
|
match="The provided JSON schema contains features "
|
||||||
"not supported by xgrammar."):
|
"not supported by xgrammar."):
|
||||||
@ -166,6 +178,26 @@ def test_guided_json_unsupported_schema(
|
|||||||
] * 2,
|
] * 2,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
use_tqdm=True)
|
use_tqdm=True)
|
||||||
|
else:
|
||||||
|
# This should work for both "guidance" and "auto".
|
||||||
|
|
||||||
|
outputs = llm.generate(
|
||||||
|
prompts=("Give an example JSON object for a grade "
|
||||||
|
"that fits this schema: "
|
||||||
|
f"{unsupported_json_schema}"),
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
use_tqdm=True)
|
||||||
|
assert outputs is not None
|
||||||
|
for output in outputs:
|
||||||
|
assert output is not None
|
||||||
|
assert isinstance(output, RequestOutput)
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
assert generated_text is not None
|
||||||
|
print(generated_text)
|
||||||
|
|
||||||
|
# Parse to verify it is valid JSON
|
||||||
|
parsed_json = json.loads(generated_text)
|
||||||
|
assert isinstance(parsed_json, dict)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip_global_cleanup
|
@pytest.mark.skip_global_cleanup
|
||||||
@ -179,13 +211,14 @@ def test_guided_grammar_ebnf(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(grammar=sample_sql_ebnf))
|
||||||
grammar=sample_sql_ebnf,
|
|
||||||
backend=guided_decoding_backend))
|
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts=("Generate a sql statement that selects col_1 from "
|
prompts=("Generate a sql statement that selects col_1 from "
|
||||||
"table_1 where it is equal to 1"),
|
"table_1 where it is equal to 1"),
|
||||||
@ -222,13 +255,14 @@ def test_guided_grammar_lark(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(grammar=sample_sql_lark))
|
||||||
grammar=sample_sql_lark,
|
|
||||||
backend=guided_decoding_backend))
|
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts=("Generate a sql statement that selects col_1 from "
|
prompts=("Generate a sql statement that selects col_1 from "
|
||||||
"table_1 where it is equal to 1"),
|
"table_1 where it is equal to 1"),
|
||||||
@ -269,16 +303,15 @@ def test_guided_grammar_ebnf_invalid(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
max_tokens=1000,
|
max_tokens=1000,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(grammar="not a grammar"))
|
||||||
grammar="not a grammar",
|
with pytest.raises(ValueError, match="Failed to convert the grammar "):
|
||||||
backend=guided_decoding_backend))
|
|
||||||
with pytest.raises(ValueError,
|
|
||||||
match="Failed to convert the grammar "
|
|
||||||
"from Lark to EBNF."):
|
|
||||||
llm.generate(
|
llm.generate(
|
||||||
prompts=("Generate a sql statement that selects col_1 from "
|
prompts=("Generate a sql statement that selects col_1 from "
|
||||||
"table_1 where it is equal to 1"),
|
"table_1 where it is equal to 1"),
|
||||||
@ -298,12 +331,13 @@ def test_guided_regex(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(regex=sample_regex))
|
||||||
regex=sample_regex,
|
|
||||||
backend=guided_decoding_backend))
|
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts=[
|
prompts=[
|
||||||
f"Give an example IPv4 address with this regex: {sample_regex}"
|
f"Give an example IPv4 address with this regex: {sample_regex}"
|
||||||
@ -335,12 +369,13 @@ def test_guided_choice_completion(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
):
|
):
|
||||||
monkeypatch.setenv("VLLM_USE_V1", "1")
|
monkeypatch.setenv("VLLM_USE_V1", "1")
|
||||||
llm = LLM(model=model_name, max_model_len=1024)
|
llm = LLM(model=model_name,
|
||||||
sampling_params = SamplingParams(temperature=0.8,
|
max_model_len=1024,
|
||||||
|
guided_decoding_backend=guided_decoding_backend)
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.8,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
guided_decoding=GuidedDecodingParams(
|
guided_decoding=GuidedDecodingParams(choice=sample_guided_choice))
|
||||||
choice=sample_guided_choice,
|
|
||||||
backend=guided_decoding_backend))
|
|
||||||
outputs = llm.generate(
|
outputs = llm.generate(
|
||||||
prompts="The best language for type-safe systems programming is ",
|
prompts="The best language for type-safe systems programming is ",
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
|
@ -2800,12 +2800,17 @@ class DecodingConfig:
|
|||||||
return hash_str
|
return hash_str
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
valid_guided_backends = [
|
v0_valid_guided_backends = [
|
||||||
'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
|
'outlines', 'lm-format-enforcer', 'xgrammar'
|
||||||
]
|
]
|
||||||
|
v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
|
||||||
|
|
||||||
backend = GuidedDecodingParams(
|
backend = GuidedDecodingParams(
|
||||||
backend=self.guided_decoding_backend).backend_name
|
backend=self.guided_decoding_backend).backend_name
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
valid_guided_backends = v1_valid_guided_backends
|
||||||
|
else:
|
||||||
|
valid_guided_backends = v0_valid_guided_backends
|
||||||
if backend not in valid_guided_backends:
|
if backend not in valid_guided_backends:
|
||||||
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
|
raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
|
||||||
f" must be one of {valid_guided_backends}")
|
f" must be one of {valid_guided_backends}")
|
||||||
|
@ -391,16 +391,13 @@ class EngineArgs:
|
|||||||
default='xgrammar',
|
default='xgrammar',
|
||||||
help='Which engine will be used for guided decoding'
|
help='Which engine will be used for guided decoding'
|
||||||
' (JSON schema / regex etc) by default. Currently support '
|
' (JSON schema / regex etc) by default. Currently support '
|
||||||
'https://github.com/outlines-dev/outlines, '
|
'https://github.com/mlc-ai/xgrammar and '
|
||||||
'https://github.com/mlc-ai/xgrammar, and '
|
'https://github.com/guidance-ai/llguidance.'
|
||||||
'https://github.com/noamgat/lm-format-enforcer.'
|
'Valid backend values are "xgrammar", "guidance", and "auto". '
|
||||||
' Can be overridden per request via guided_decoding_backend'
|
'With "auto", we will make opinionated choices based on request'
|
||||||
' parameter.\n'
|
'contents and what the backend libraries currently support, so '
|
||||||
'Backend-specific options can be supplied in a comma-separated '
|
'the behavior is subject to change in each release. '
|
||||||
'list following a colon after the backend name. Valid backends and '
|
'The default is xgrammar.')
|
||||||
'all available options are: [xgrammar:no-fallback, '
|
|
||||||
'xgrammar:disable-any-whitespace, '
|
|
||||||
'outlines:no-fallback, lm-format-enforcer:no-fallback]')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--logits-processor-pattern',
|
'--logits-processor-pattern',
|
||||||
type=nullable_str,
|
type=nullable_str,
|
||||||
@ -1539,9 +1536,9 @@ class EngineArgs:
|
|||||||
recommend_to_remove=False)
|
recommend_to_remove=False)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Only support Xgrammar for guided decoding so far.
|
# Xgrammar and Guidance are supported.
|
||||||
SUPPORTED_GUIDED_DECODING = [
|
SUPPORTED_GUIDED_DECODING = [
|
||||||
"xgrammar", "xgrammar:disable-any-whitespace"
|
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
|
||||||
]
|
]
|
||||||
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
|
if self.guided_decoding_backend not in SUPPORTED_GUIDED_DECODING:
|
||||||
_raise_or_fallback(feature_name="--guided-decoding-backend",
|
_raise_or_fallback(feature_name="--guided-decoding-backend",
|
||||||
|
@ -4,7 +4,6 @@ import time
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import vllm.platforms
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||||
PromptType, SingletonInputsAdapter)
|
PromptType, SingletonInputsAdapter)
|
||||||
@ -20,7 +19,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
from vllm.v1.structured_output.utils import validate_structured_output_request
|
from vllm.v1.structured_output.backend_guidance import (
|
||||||
|
validate_guidance_grammar)
|
||||||
|
from vllm.v1.structured_output.utils import (
|
||||||
|
validate_structured_output_request_xgrammar)
|
||||||
|
|
||||||
|
|
||||||
class Processor:
|
class Processor:
|
||||||
@ -120,7 +122,9 @@ class Processor:
|
|||||||
if not params.guided_decoding or not self.decoding_config:
|
if not params.guided_decoding or not self.decoding_config:
|
||||||
return
|
return
|
||||||
|
|
||||||
supported_backends = ["xgrammar", "xgrammar:disable-any-whitespace"]
|
supported_backends = [
|
||||||
|
"xgrammar", "xgrammar:disable-any-whitespace", "guidance", "auto"
|
||||||
|
]
|
||||||
engine_level_backend = self.decoding_config.guided_decoding_backend
|
engine_level_backend = self.decoding_config.guided_decoding_backend
|
||||||
if engine_level_backend not in supported_backends:
|
if engine_level_backend not in supported_backends:
|
||||||
raise ValueError(f"Only {supported_backends} structured output is "
|
raise ValueError(f"Only {supported_backends} structured output is "
|
||||||
@ -134,10 +138,31 @@ class Processor:
|
|||||||
else:
|
else:
|
||||||
params.guided_decoding.backend = engine_level_backend
|
params.guided_decoding.backend = engine_level_backend
|
||||||
|
|
||||||
if vllm.platforms.current_platform.is_tpu():
|
# Request content validation
|
||||||
raise ValueError("Structured output is not supported on TPU.")
|
|
||||||
|
|
||||||
validate_structured_output_request(params)
|
if engine_level_backend == "xgrammar":
|
||||||
|
# xgrammar with no fallback
|
||||||
|
validate_structured_output_request_xgrammar(params)
|
||||||
|
params.guided_decoding.backend = "xgrammar"
|
||||||
|
elif engine_level_backend == "auto":
|
||||||
|
# "auto" is an opt-in to opinionated behavior where we try to
|
||||||
|
# choose a backend based on request contents. This is not the
|
||||||
|
# default as it is less predictable and subject to change
|
||||||
|
# between releases as feature support changes.
|
||||||
|
try:
|
||||||
|
validate_structured_output_request_xgrammar(params)
|
||||||
|
params.guided_decoding.backend = "xgrammar"
|
||||||
|
except ValueError:
|
||||||
|
# The request includes some jsonschema feature(s) that
|
||||||
|
# are not supported in xgrammar. Fall back to guidance.
|
||||||
|
params.guided_decoding.backend = "guidance"
|
||||||
|
|
||||||
|
if params.guided_decoding.backend == "guidance":
|
||||||
|
# TODO ideally we would have the LLTokenizer here as Lark syntax
|
||||||
|
# allows <|special_token|> and similar, see
|
||||||
|
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
|
||||||
|
# Without tokenizer these are disallowed in grammars.
|
||||||
|
validate_guidance_grammar(params, tokenizer=None)
|
||||||
|
|
||||||
def process_inputs(
|
def process_inputs(
|
||||||
self,
|
self,
|
||||||
|
@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.structured_output.backend_guidance import GuidanceBackend
|
||||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||||
StructuredOutputGrammar)
|
StructuredOutputGrammar)
|
||||||
|
|
||||||
@ -50,6 +51,8 @@ class StructuredOutputManager:
|
|||||||
XgrammarBackend)
|
XgrammarBackend)
|
||||||
|
|
||||||
self.backend = XgrammarBackend(self.vllm_config)
|
self.backend = XgrammarBackend(self.vllm_config)
|
||||||
|
elif backend_name == "guidance":
|
||||||
|
self.backend = GuidanceBackend(self.vllm_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported structured output backend: {backend_name}")
|
f"Unsupported structured output backend: {backend_name}")
|
||||||
|
164
vllm/v1/structured_output/backend_guidance.py
Normal file
164
vllm/v1/structured_output/backend_guidance.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||||
|
from vllm.utils import LazyLoader
|
||||||
|
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||||
|
StructuredOutputGrammar,
|
||||||
|
StructuredOutputOptions)
|
||||||
|
from vllm.v1.structured_output.request import get_structured_output_key
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import llguidance
|
||||||
|
import llguidance.hf as llguidance_hf
|
||||||
|
import llguidance.torch as llguidance_torch
|
||||||
|
else:
|
||||||
|
llguidance = LazyLoader("llguidance", globals(), "llguidance")
|
||||||
|
llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
|
||||||
|
llguidance_torch = LazyLoader("llguidance.torch", globals(),
|
||||||
|
"llguidance.torch")
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class GuidanceBackend(StructuredOutputBackend):
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: VllmConfig):
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
tokenizer_group = init_tokenizer_from_configs(
|
||||||
|
model_config=vllm_config.model_config,
|
||||||
|
scheduler_config=vllm_config.scheduler_config,
|
||||||
|
parallel_config=vllm_config.parallel_config,
|
||||||
|
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
|
||||||
|
tokenizer_group.ping()
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.vocab_size = vllm_config.model_config.get_vocab_size()
|
||||||
|
|
||||||
|
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||||
|
self.ll_tokenizer = llguidance_hf.from_tokenizer(tokenizer, None)
|
||||||
|
|
||||||
|
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||||
|
grammar_spec: str) -> StructuredOutputGrammar:
|
||||||
|
self.serialized_grammar = serialize_guidance_grammar(
|
||||||
|
request_type, grammar_spec)
|
||||||
|
|
||||||
|
ll_matcher = llguidance.LLMatcher(
|
||||||
|
self.ll_tokenizer,
|
||||||
|
self.serialized_grammar,
|
||||||
|
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
|
||||||
|
)
|
||||||
|
|
||||||
|
r = GuidanceGrammar(
|
||||||
|
ll_matcher=ll_matcher,
|
||||||
|
ll_tokenizer=self.ll_tokenizer,
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
r.check_error()
|
||||||
|
return r
|
||||||
|
|
||||||
|
def allocate_token_bitmask(self, max_num_seqs: int):
|
||||||
|
return llguidance_torch.allocate_token_bitmask(
|
||||||
|
max_num_seqs, self.ll_tokenizer.vocab_size)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GuidanceGrammar(StructuredOutputGrammar):
|
||||||
|
ll_matcher: llguidance.LLMatcher
|
||||||
|
ll_tokenizer: llguidance.LLTokenizer
|
||||||
|
vocab_size: int
|
||||||
|
printed_error: bool = False
|
||||||
|
terminated: bool = False
|
||||||
|
|
||||||
|
def check_error(self):
|
||||||
|
if not self.printed_error:
|
||||||
|
err = self.ll_matcher.get_error()
|
||||||
|
if err:
|
||||||
|
self.printed_error = True
|
||||||
|
logger.warning("LLMatcher error: %s", err)
|
||||||
|
|
||||||
|
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||||
|
"""Accepts a list of tokens and advances the parser.
|
||||||
|
|
||||||
|
Returns True if the parser was advanced successfully.
|
||||||
|
Returns False if the parser failed to advance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.ll_tokenizer.eos_token in tokens:
|
||||||
|
self.terminated = True
|
||||||
|
|
||||||
|
if self.ll_matcher.is_stopped():
|
||||||
|
return True
|
||||||
|
|
||||||
|
# TODO - Add jump decoding support in the future:
|
||||||
|
# self.ll_matcher.compute_ff_bytes() - this should always work
|
||||||
|
# self.ll_matcher.compute_ff_tokens() - this only works for
|
||||||
|
# "canonical" tokenizers
|
||||||
|
# For conversion between the two, see
|
||||||
|
# https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md
|
||||||
|
|
||||||
|
r = self.ll_matcher.consume_tokens(tokens)
|
||||||
|
|
||||||
|
self.check_error()
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
|
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
|
||||||
|
# this will automatically return [EOS] mask if the matcher is stopped
|
||||||
|
# or otherwise in an error state
|
||||||
|
llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
|
||||||
|
self.check_error()
|
||||||
|
|
||||||
|
def is_terminated(self) -> bool:
|
||||||
|
return self.terminated
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
# This method may be not needed anymore? TODO
|
||||||
|
self.ll_matcher.reset()
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_guidance_grammar(request_type: StructuredOutputOptions,
|
||||||
|
grammar_spec: str) -> str:
|
||||||
|
if request_type == StructuredOutputOptions.JSON:
|
||||||
|
# TODO: make whitespace_flexible configurable
|
||||||
|
return llguidance.LLMatcher.grammar_from_json_schema(
|
||||||
|
grammar_spec, defaults={
|
||||||
|
"whitespace_flexible": True,
|
||||||
|
})
|
||||||
|
elif request_type == StructuredOutputOptions.JSON_OBJECT:
|
||||||
|
return llguidance.LLMatcher.grammar_from_json_schema(
|
||||||
|
'{"type": "object"}', defaults={
|
||||||
|
"whitespace_flexible": True,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
if request_type == StructuredOutputOptions.REGEX:
|
||||||
|
tp = "regex"
|
||||||
|
elif request_type == StructuredOutputOptions.GRAMMAR:
|
||||||
|
tp = "grammar"
|
||||||
|
elif request_type == StructuredOutputOptions.CHOICE:
|
||||||
|
tp = "choice"
|
||||||
|
else:
|
||||||
|
logger.error("Validation should have already occurred. "
|
||||||
|
"Please file an issue.")
|
||||||
|
raise ValueError("grammar is not of valid supported types. "
|
||||||
|
f"({request_type!s})")
|
||||||
|
return llguidance.grammar_from(tp, grammar_spec)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_guidance_grammar(
|
||||||
|
sampling_params: SamplingParams,
|
||||||
|
tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
|
||||||
|
tp, grm = get_structured_output_key(sampling_params)
|
||||||
|
guidance_grm = serialize_guidance_grammar(tp, grm)
|
||||||
|
err = llguidance.LLMatcher.validate_grammar(guidance_grm,
|
||||||
|
tokenizer=tokenizer)
|
||||||
|
if err:
|
||||||
|
raise ValueError(f"Grammar error: {err}")
|
@ -53,7 +53,12 @@ class StructuredOutputRequest:
|
|||||||
|
|
||||||
@functools.cached_property
|
@functools.cached_property
|
||||||
def structured_output_key(self) -> StructuredOutputKey:
|
def structured_output_key(self) -> StructuredOutputKey:
|
||||||
params = self.sampling_params.guided_decoding
|
return get_structured_output_key(self.sampling_params)
|
||||||
|
|
||||||
|
|
||||||
|
def get_structured_output_key(
|
||||||
|
sampling_params: SamplingParams) -> StructuredOutputKey:
|
||||||
|
params = sampling_params.guided_decoding
|
||||||
assert params is not None, "params can't be None."
|
assert params is not None, "params can't be None."
|
||||||
if params.json is not None:
|
if params.json is not None:
|
||||||
if not isinstance(params.json, str):
|
if not isinstance(params.json, str):
|
||||||
|
@ -239,7 +239,7 @@ def choice_as_grammar(choice: list[str]) -> str:
|
|||||||
return grammar
|
return grammar
|
||||||
|
|
||||||
|
|
||||||
def validate_structured_output_request(
|
def validate_structured_output_request_xgrammar(
|
||||||
sampling_params: SamplingParams) -> None:
|
sampling_params: SamplingParams) -> None:
|
||||||
"""Validate that the request is supported by structured output.
|
"""Validate that the request is supported by structured output.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user