[Core][V0] Add guidance backend for structured output (#14589)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Loc Huynh <lohuynh@microsoft.com>
Co-authored-by: Michal Moskal <michal@moskal.me>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
Russell Bryant 2025-03-20 00:33:51 -04:00 committed by GitHub
parent b88be22165
commit 1f16b7fe74
8 changed files with 167 additions and 13 deletions

View File

@@ -999,11 +999,12 @@ if __name__ == "__main__":
         type=float,
         default=1.0,
         help="Ratio of Structured Outputs requests")
-    parser.add_argument("--structured-output-backend",
-                        type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
-                        help="Backend to use for structured outputs")
+    parser.add_argument(
+        "--structured-output-backend",
+        type=str,
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
+        default="xgrammar",
+        help="Backend to use for structured outputs")
 
     args = parser.parse_args()
     main(args)
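
For context, this flag selects which backend the structured-output benchmark exercises, and it can now target guidance as well. A minimal invocation sketch (the script path is an assumption based on vLLM's benchmarks directory; all other flags left at their defaults):

python benchmarks/benchmark_serving_structured_output.py --structured-output-backend guidance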

View File

@@ -18,6 +18,7 @@ pillow  # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
+llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
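
The environment marker restricts the new llguidance dependency to x86_64, arm64, and aarch64. To pull in the pinned range by hand (a sketch; a normal vLLM install resolves this requirements file automatically):

pip install 'llguidance>=0.7.2,<0.8.0'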

View File

@@ -14,7 +14,9 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 
 
 @pytest.fixture(scope="module")

View File

@@ -16,7 +16,9 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import (
 from vllm.sampling_params import GuidedDecodingParams
 
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
 REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

View File

@@ -2785,7 +2785,9 @@ class DecodingConfig:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
+        valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        ]
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
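
A quick sketch of what this validation now accepts and rejects (illustrative; assumes DecodingConfig is constructed directly, as in vllm.config):

from vllm.config import DecodingConfig

# 'guidance' is now in the valid list...
DecodingConfig(guided_decoding_backend="guidance")
# ...while anything outside it still raises ValueError in __post_init__.
DecodingConfig(guided_decoding_backend="not-a-backend")  # raises ValueError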

View File

@@ -79,6 +79,12 @@ def maybe_backend_fallback(
                 "xgrammar does not support Lark grammars and the "
                 "grammar failed to convert to GBNF.", "outlines")
 
+        elif guided_params.json_object:
+            # https://github.com/mlc-ai/xgrammar/issues/256
+            fallback_or_error(guided_params,
+                              "xgrammar does not support json_object.",
+                              "guidance")
+
         # If the xgrammar module cannot be imported successfully,
         # we should still allow users to use guided decoding with a fallback.
         elif not xgr_installed:
@@ -88,9 +94,9 @@ def maybe_backend_fallback(
 
     if (guided_params.backend_name == "outlines"
            and guided_params.json_object is not None):
-        # outlines doesn't support json_object, fallback to xgrammar
+        # outlines doesn't support json_object, fallback to guidance
         fallback_or_error(guided_params,
-                          "outlines does not support json_object.", "xgrammar")
+                          "outlines does not support json_object.", "guidance")
 
     return guided_params
@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', "
+        "'guidance'")
@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', "
+        "'guidance'")
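
Taken together, these hunks change where json_object requests land: both xgrammar and outlines now fall back to guidance rather than to each other. A minimal sketch of the resulting behavior, assuming maybe_backend_fallback is called as in this file and that fallback (rather than error) is permitted:

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# A json_object request pinned to xgrammar...
params = GuidedDecodingParams(backend="xgrammar", json_object=True)
# ...is rewritten to the guidance backend by the fallback logic above.
params = maybe_backend_fallback(params)
print(params.backend_name)  # expected: "guidance"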

View File

@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+from re import escape as regex_escape
+
+import llguidance
+from transformers import PreTrainedTokenizerBase
+
+from vllm.model_executor.guided_decoding.guidance_logits_processors import (
+    GuidanceLogitsProcessor)
+from vllm.sampling_params import GuidedDecodingParams
+
+
+def get_local_guidance_guided_decoding_logits_processor(
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    """
+    grm = ""
+    if guided_params.json:
+        grm = llguidance.LLMatcher.grammar_from_json_schema(
+            guided_params.json,
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+    elif guided_params.json_object:
+        grm = llguidance.LLMatcher.grammar_from_json_schema(
+            '{"type": "object"}',
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+    elif guided_params.regex:
+        grm = llguidance.grammar_from("regex", guided_params.regex)
+    elif guided_params.choice:
+        # choice just uses regex
+        choices = (regex_escape(str(choice))
+                   for choice in guided_params.choice)
+        choices_regex = "(" + "|".join(choices) + ")"
+        grm = llguidance.grammar_from("regex", choices_regex)
+    elif guided_params.grammar:
+        # this supports Lark and GBNF
+        grm = llguidance.grammar_from("grammar", guided_params.grammar)
+
+    if grm:
+        return GuidanceLogitsProcessor(grm, tokenizer)
+
+    raise ValueError("Unknown guided decoding mode")
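
For completeness, a hedged end-to-end sketch of reaching this function through vLLM's offline API (model name borrowed from the test file above; the guided_decoding plumbing is assumed to route backend="guidance" here):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")
sampling = SamplingParams(guided_decoding=GuidedDecodingParams(
    backend="guidance",
    json={"type": "object", "properties": {"name": {"type": "string"}}},
))
outputs = llm.generate("Reply with a JSON object that has a name field.",
                       sampling)
print(outputs[0].outputs[0].text)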

View File

@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Any, List
+
+import llguidance
+import llguidance.hf
+import llguidance.torch
+import torch
+from transformers import PreTrainedTokenizerBase
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class GuidanceLogitsProcessor:
+    """Base Guidance Logits Processor"""
+
+    cached_tokenizers: dict[str, Any] = {}
+
+    def __init__(
+        self,
+        grammar: str,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> None:
+        """Base Guidance Logits Processor
+
+        Args:
+            grammar (str)
+                grammar to guide the generation
+            tokenizer (PreTrainedTokenizerBase)
+                model's tokenizer
+        """
+        self.grammar = grammar
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer.name_or_path
+        self.new_sampling = False
+        self.initialized = False
+
+    def _initialize(self):
+        if self.initialized:
+            return
+
+        ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
+                                                  None)
+        if ll_tokenizer is None:
+            ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
+            self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer
+
+        self.ll_tokenizer = ll_tokenizer
+        self.ll_matcher = llguidance.LLMatcher(
+            self.ll_tokenizer,
+            self.grammar,
+            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+        )
+
+        # create reusable bitmask
+        self.bitmask = llguidance.torch.allocate_token_bitmask(
+            1, self.ll_tokenizer.vocab_size)
+
+        self.initialized = True
+
+    def __call__(
+        self,
+        input_ids: List[int],
+        scores: torch.Tensor,
+    ) -> torch.Tensor:
+        # we initialize the guidance model here
+        # to avoid pickling ll_tokenizer and ll_interpreter
+        self._initialize()
+
+        if self.new_sampling and len(input_ids) > 0:
+            self.ll_matcher.consume_token(input_ids[-1])
+            err = self.ll_matcher.get_error()
+            if err:
+                logger.warning("Error in LLMatcher: %s", err)
+
+        llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
+                                                 0)
+        llguidance.torch.apply_token_bitmask_inplace(
+            scores, self.bitmask.to(scores.device))
+
+        self.new_sampling = True
+
+        return scores
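
The masking step above relies on llguidance's token bitmask: conceptually, apply_token_bitmask_inplace drives the logits of grammar-forbidden tokens to -inf so sampling can only pick legal continuations. A plain-torch sketch of that idea (illustrative only, not the llguidance implementation):

import torch

def apply_mask_sketch(scores: torch.Tensor,
                      allowed: torch.Tensor) -> torch.Tensor:
    # scores: (vocab_size,) logits; allowed: (vocab_size,) bool mask of
    # tokens the grammar permits at this position.
    scores[~allowed] = float("-inf")
    return scores

# Only tokens 1 and 3 are legal here; everything else is masked out.
logits = torch.tensor([0.5, 1.2, -0.3, 2.0])
allowed = torch.tensor([False, True, False, True])
print(apply_mask_sketch(logits, allowed))
# tensor([-inf, 1.2000, -inf, 2.0000])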