Russell Bryant 1f16b7fe74
[Core][V0] Add guidance backend for structured output (#14589)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Loc Huynh <lohuynh@microsoft.com>
Co-authored-by: Michal Moskal <michal@moskal.me>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>
2025-03-19 21:33:51 -07:00


# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

from typing import TYPE_CHECKING

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import (
    convert_lark_to_gbnf, grammar_is_likely_lark,
    has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from vllm.config import ModelConfig
    from vllm.logits_process import LogitsProcessor
    from vllm.sampling_params import GuidedDecodingParams

logger = init_logger(__name__)
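
# Backends handled here: "outlines", "lm-format-enforcer", "xgrammar" and
# "guidance". maybe_backend_fallback() below rewrites guided_params.backend
# when the requested backend cannot serve the request (e.g. xgrammar falls
# back to outlines for regex, and outlines falls back to guidance for
# json_object).
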
def maybe_backend_fallback(
        guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
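    """Check that the chosen backend supports the request; fall back if not.

    If the requested backend cannot handle the given parameters, switch
    `guided_params.backend` to one that can and log a warning, or raise a
    ValueError when fallback is disabled (see `fallback_or_error` below).
    """
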
    def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
                          fallback: str) -> None:
        """Change the backend to the specified fallback with a warning log,
        or raise a ValueError if the `no-fallback` option is specified."""
        if guided_params.no_fallback():
            raise ValueError(message)

        logger.warning("%s Falling back to use %s instead.", message, fallback)
        guided_params.backend = fallback
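
    # Note: fallback is disabled through an option on the backend string
    # (e.g. "xgrammar:no-fallback"); guided_params.no_fallback() reports
    # whether that option was set.
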
    # lm-format-enforcer doesn't support grammar, fall back to xgrammar
    if guided_params.backend_name == "lm-format-enforcer":
        if guided_params.grammar is not None:
            fallback_or_error(
                guided_params,
                "lm-format-enforcer does not support grammar guided decoding.",
                "xgrammar")

        # lm-format-enforcer doesn't support some JSON schema features
        elif (guided_params.json is not None
              and has_lmf_unsupported_json_features(guided_params.json)):
            fallback_or_error(
                guided_params,
                "lm-format-enforcer does not support advanced JSON schema "
                "features like patterns or numeric ranges.", "outlines")
if guided_params.backend_name == "xgrammar":
from vllm.model_executor.guided_decoding.xgrammar_decoding import (
xgr_installed)
# xgrammar doesn't support regex, fallback to outlines
if guided_params.regex is not None:
fallback_or_error(
guided_params,
"xgrammar does not support regex guided decoding.", "outlines")
# xgrammar doesn't support some JSON schema features
elif (guided_params.json is not None
and has_xgrammar_unsupported_json_features(guided_params.json)):
fallback_or_error(
guided_params,
"xgrammar does not support advanced JSON schema features like "
"enums, patterns or numeric ranges.", "outlines")
# xgrammar only supports GBNF grammars, so we must convert Lark.
# We must check if the grammar is likely Lark and if that
# grammar is convertible to GBNF
elif (guided_params.grammar is not None
and grammar_is_likely_lark(guided_params.grammar)):
try:
convert_lark_to_gbnf(guided_params.grammar)
except Exception:
fallback_or_error(
guided_params,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")
elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")
# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
fallback_or_error(
guided_params,
"xgrammar module cannot be imported successfully.", "outlines")
if (guided_params.backend_name == "outlines"
and guided_params.json_object is not None):
# outlines doesn't support json_object, fallback to guidance
fallback_or_error(guided_params,
"outlines does not support json_object.", "guidance")
return guided_params
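
# A minimal sketch of the fallback chain (hypothetical values; assumes the
# GuidedDecodingParams constructor accepts these fields):
#
#     from vllm.sampling_params import GuidedDecodingParams
#     params = GuidedDecodingParams(json_object=True, backend="xgrammar")
#     params = maybe_backend_fallback(params)
#     # xgrammar cannot serve json_object, so params.backend is now
#     # "guidance" and a warning was logged.
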
async def get_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        reasoning_backend: str | None = None) -> LogitsProcessor | None:
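    """Build a logits processor for the resolved guided decoding backend.

    Only the outlines backend is awaited here; the other backends build
    their processors synchronously. May return None when the backend
    produces no processor for the given parameters.
    """
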
    reasoner = get_reasoner(tokenizer, reasoning_backend)

    guided_params = maybe_backend_fallback(guided_params)

    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend_name == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
            get_outlines_guided_decoding_logits_processor)
        return await get_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer, reasoner)
    if guided_params.backend_name == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend_name == 'xgrammar':
        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config, reasoner)
    if guided_params.backend_name == 'guidance':
        from vllm.model_executor.guided_decoding.guidance_decoding import (
            get_local_guidance_guided_decoding_logits_processor)
        return get_local_guidance_guided_decoding_logits_processor(
            guided_params, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', "
        "'guidance'")
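
# Hypothetical call site (the surrounding names are illustrative, not part
# of this module):
#
#     processor = await get_guided_decoding_logits_processor(
#         guided_params, tokenizer, model_config)
#     if processor is not None:
#         sampling_params.logits_processors = [processor]
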
def get_local_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizer,
        model_config: ModelConfig,
        reasoning_backend: str | None = None) -> LogitsProcessor | None:
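    """Synchronous counterpart of get_guided_decoding_logits_processor.

    Resolves the backend the same way but never awaits, so it is usable
    where no event loop is running.
    """
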
    guided_params = maybe_backend_fallback(guided_params)

    # Get the reasoner if needed; it will be None if reasoning_backend is None.
    reasoner = get_reasoner(tokenizer, reasoning_backend)

    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend_name == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
        from vllm.model_executor.guided_decoding.outlines_decoding import (  # noqa
            get_local_outlines_guided_decoding_logits_processor)
        return get_local_outlines_guided_decoding_logits_processor(
            guided_params, tokenizer, reasoner)
    if guided_params.backend_name == 'lm-format-enforcer':
        from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (  # noqa
            get_local_lm_format_enforcer_guided_decoding_logits_processor)
        return get_local_lm_format_enforcer_guided_decoding_logits_processor(
            guided_params, tokenizer)
    if guided_params.backend_name == 'xgrammar':
        from vllm.model_executor.guided_decoding.xgrammar_decoding import (  # noqa
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config, reasoner)
    if guided_params.backend_name == 'guidance':
        from vllm.model_executor.guided_decoding.guidance_decoding import (
            get_local_guidance_guided_decoding_logits_processor)
        return get_local_guidance_guided_decoding_logits_processor(
            guided_params, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', "
        "'guidance'")