[Core][V0] Add guidance backend for structured output (#14589)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Loc Huynh <lohuynh@microsoft.com>
Co-authored-by: Michal Moskal <michal@moskal.me>
Co-authored-by: Aaron Pham <contact@aarnphm.xyz>

parent b88be22165
commit 1f16b7fe74
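
How a request reaches the new backend, as a minimal sketch. Only the backend name "guidance" and the GuidedDecodingParams fields are confirmed by this diff; the offline LLM/SamplingParams surface is the V0 API assumed here.

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Model name borrowed from the tests below; any supported model works.
llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")

sampling = SamplingParams(
    max_tokens=100,
    guided_decoding=GuidedDecodingParams(
        json={"type": "object", "properties": {"name": {"type": "string"}}},
        backend="guidance",  # backend added by this commit
    ),
)
outputs = llm.generate("Return a JSON object with a name field.", sampling)
print(outputs[0].outputs[0].text)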
@@ -999,11 +999,12 @@ if __name__ == "__main__":
                         type=float,
                         default=1.0,
                         help="Ratio of Structured Outputs requests")
-    parser.add_argument("--structured-output-backend",
-                        type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
-                        help="Backend to use for structured outputs")
+    parser.add_argument(
+        "--structured-output-backend",
+        type=str,
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
+        default="xgrammar",
+        help="Backend to use for structured outputs")
 
     args = parser.parse_args()
     main(args)
@@ -18,6 +18,7 @@ pillow  # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0  # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
+llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
@@ -14,7 +14,9 @@ from vllm.outputs import RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 
 
 @pytest.fixture(scope="module")
@@ -16,7 +16,9 @@ from vllm.model_executor.guided_decoding.outlines_logits_processors import (
 from vllm.sampling_params import GuidedDecodingParams
 
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
 REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
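
These backend lists typically drive parametrized tests; a sketch of the usual pattern (the test body here is illustrative, not part of this diff):

import pytest

from vllm.sampling_params import GuidedDecodingParams

GUIDED_DECODING_BACKENDS = [
    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
]


@pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS)
def test_backend_roundtrip(backend):
    # Illustrative: each listed backend should be accepted verbatim.
    params = GuidedDecodingParams(regex=r"\d+", backend=backend)
    assert params.backend_name == backend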
@@ -2785,7 +2785,9 @@ class DecodingConfig:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
+        valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        ]
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
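
With 'guidance' in valid_guided_backends, config validation now accepts the new name. A sketch, assuming DecodingConfig is constructible on its own and that __post_init__ raises ValueError for names outside the list (the raise itself is outside this hunk):

from vllm.config import DecodingConfig

# Accepted after this commit; previously rejected in __post_init__.
DecodingConfig(guided_decoding_backend="guidance")

# Unknown backends should still be rejected.
try:
    DecodingConfig(guided_decoding_backend="not-a-backend")
except ValueError as err:
    print(err)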
@@ -79,6 +79,12 @@ def maybe_backend_fallback(
                 "xgrammar does not support Lark grammars and the "
                 "grammar failed to convert to GBNF.", "outlines")
 
+        elif guided_params.json_object:
+            # https://github.com/mlc-ai/xgrammar/issues/256
+            fallback_or_error(guided_params,
+                              "xgrammar does not support json_object.",
+                              "guidance")
+
         # If the xgrammar module cannot be imported successfully,
         # we should still allow users to use guided decoding with a fallback.
         elif not xgr_installed:
@@ -88,9 +94,9 @@ def maybe_backend_fallback(
 
     if (guided_params.backend_name == "outlines"
             and guided_params.json_object is not None):
-        # outlines doesn't support json_object, fallback to xgrammar
+        # outlines doesn't support json_object, fallback to guidance
         fallback_or_error(guided_params,
-                          "outlines does not support json_object.", "xgrammar")
+                          "outlines does not support json_object.", "guidance")
 
     return guided_params
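
Net effect of the two hunks above, sketched: json_object requests now land on guidance whether they start on xgrammar or outlines, assuming fallback_or_error falls back with a warning (rather than raising) in the default configuration:

from vllm.model_executor.guided_decoding import maybe_backend_fallback
from vllm.sampling_params import GuidedDecodingParams

# xgrammar cannot serve json_object (mlc-ai/xgrammar#256) -> guidance.
params = maybe_backend_fallback(
    GuidedDecodingParams(json_object=True, backend="xgrammar"))
print(params.backend_name)  # expected: "guidance"

# outlines cannot serve json_object either; it previously fell back to
# xgrammar and now falls back to guidance.
params = maybe_backend_fallback(
    GuidedDecodingParams(json_object=True, backend="outlines"))
print(params.backend_name)  # expected: "guidance"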
@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
-
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
+    )
 
 
 def get_local_guided_decoding_logits_processor(
@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
+    )
vllm/model_executor/guided_decoding/guidance_decoding.py (new file, 44 lines)
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+from re import escape as regex_escape
+
+import llguidance
+from transformers import PreTrainedTokenizerBase
+
+from vllm.model_executor.guided_decoding.guidance_logits_processors import (
+    GuidanceLogitsProcessor)
+from vllm.sampling_params import GuidedDecodingParams
+
+
+def get_local_guidance_guided_decoding_logits_processor(
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    """
+
+    grm = ""
+    if guided_params.json:
+        grm = llguidance.LLMatcher.grammar_from_json_schema(
+            guided_params.json,
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+    elif guided_params.json_object:
+        grm = llguidance.LLMatcher.grammar_from_json_schema(
+            '{"type": "object"}',
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+    elif guided_params.regex:
+        grm = llguidance.grammar_from("regex", guided_params.regex)
+    elif guided_params.choice:
+        # choice just uses regex
+        choices = (regex_escape(str(choice))
+                   for choice in guided_params.choice)
+        choices_regex = "(" + "|".join(choices) + ")"
+        grm = llguidance.grammar_from("regex", choices_regex)
+    elif guided_params.grammar:
+        # this supports Lark and GBNF
+        grm = llguidance.grammar_from("grammar", guided_params.grammar)
+
+    if grm:
+        return GuidanceLogitsProcessor(grm, tokenizer)
+
+    raise ValueError("Unknown guided decoding mode")
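
A sketch of driving the new helper directly (the tokenizer choice is illustrative; requires the llguidance dependency added above):

from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.guidance_decoding import (
    get_local_guidance_guided_decoding_logits_processor)
from vllm.sampling_params import GuidedDecodingParams

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

# A choice guide compiles to the regex branch shown above.
processor = get_local_guidance_guided_decoding_logits_processor(
    GuidedDecodingParams(choice=["yes", "no"], backend="guidance"),
    tokenizer)
print(type(processor).__name__)  # GuidanceLogitsProcessor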
vllm/model_executor/guided_decoding/guidance_logits_processors.py (new file, 85 lines)
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Any, List
+
+import llguidance
+import llguidance.hf
+import llguidance.torch
+import torch
+from transformers import PreTrainedTokenizerBase
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class GuidanceLogitsProcessor:
+    """Base Guidance Logits Processor"""
+
+    cached_tokenizers: dict[str, Any] = {}
+
+    def __init__(
+        self,
+        grammar: str,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> None:
+        """Base Guidance Logits Processor
+
+        Args:
+            grammar (str)
+                grammar to guide the generation
+            tokenizer (PreTrainedTokenizerBase)
+                model's tokenizer
+        """
+        self.grammar = grammar
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer.name_or_path
+        self.new_sampling = False
+        self.initialized = False
+
+    def _initialize(self):
+        if self.initialized:
+            return
+
+        ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
+                                                  None)
+        if ll_tokenizer is None:
+            ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
+            self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer
+
+        self.ll_tokenizer = ll_tokenizer
+        self.ll_matcher = llguidance.LLMatcher(
+            self.ll_tokenizer,
+            self.grammar,
+            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+        )
+
+        # create reusable bitmask
+        self.bitmask = llguidance.torch.allocate_token_bitmask(
+            1, self.ll_tokenizer.vocab_size)
+
+        self.initialized = True
+
+    def __call__(
+        self,
+        input_ids: List[int],
+        scores: torch.Tensor,
+    ) -> torch.Tensor:
+        # we initialize the guidance model here
+        # to avoid pickling ll_tokenizer and ll_interpreter
+        self._initialize()
+
+        if self.new_sampling and len(input_ids) > 0:
+            self.ll_matcher.consume_token(input_ids[-1])
+            err = self.ll_matcher.get_error()
+            if err:
+                logger.warning("Error in LLMatcher: %s", err)
+
+        llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
+                                                 0)
+        llguidance.torch.apply_token_bitmask_inplace(
+            scores, self.bitmask.to(scores.device))
+
+        self.new_sampling = True
+
+        return scores
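
A sketch of the processor's per-step contract: given the token ids generated so far and a logits row, it masks grammar-illegal tokens in place. Sizing the row from llguidance's own tokenizer keeps the bitmask and logits shapes consistent; in vLLM the row comes from the model's logits instead.

import llguidance
import llguidance.hf
import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.guidance_logits_processors import (
    GuidanceLogitsProcessor)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
# Grammar built the same way guidance_decoding.py does for regex guides.
grammar = llguidance.grammar_from("regex", r"(yes|no)")
processor = GuidanceLogitsProcessor(grammar, tokenizer)

# One decode step with no tokens consumed yet.
vocab_size = llguidance.hf.from_tokenizer(tokenizer, None).vocab_size
scores = torch.zeros(vocab_size)
masked = processor([], scores)
print(torch.isinf(masked).sum())  # count of tokens disallowed by the grammar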