# SPDX-License-Identifier: Apache-2.0

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
from vllm.v1.structured_output.request import get_structured_output_key
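
# llguidance is imported lazily below, so the dependency is only pulled in
# when the guidance backend is actually used.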
if TYPE_CHECKING:
    import llguidance
    import llguidance.hf as llguidance_hf
    import llguidance.torch as llguidance_torch
else:
    llguidance = LazyLoader("llguidance", globals(), "llguidance")
    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
    llguidance_torch = LazyLoader("llguidance.torch", globals(),
                                  "llguidance.torch")

logger = init_logger(__name__)


class GuidanceBackend(StructuredOutputBackend):
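    """Structured-output backend backed by the llguidance library.

    Compiles grammar specs into llguidance matchers that constrain
    sampling via per-step token bitmasks.
    """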

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.disable_any_whitespace = (
            "disable-any-whitespace"
            in vllm_config.decoding_config.guided_decoding_backend)

        # Build an llguidance tokenizer that mirrors the HF tokenizer's vocab.
        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
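        """Compile a grammar spec into a ready-to-use GuidanceGrammar."""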
        self.serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec, self.disable_any_whitespace)

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )
        r.check_error()
        return r

    def allocate_token_bitmask(self, max_num_seqs: int):
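        """Allocate a token bitmask tensor covering the tokenizer's vocab."""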
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)


@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
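    """Per-request matcher state driving grammar-constrained decoding."""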
    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    printed_error: bool = False
    terminated: bool = False

    def check_error(self):
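        """Log the matcher's error, at most once per grammar instance."""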
        if not self.printed_error:
            err = self.ll_matcher.get_error()
            if err:
                self.printed_error = True
                logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """
        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True

        if self.ll_matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        # "canonical" tokenizers
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        r = self.ll_matcher.consume_tokens(tokens)
        self.check_error()
        return r

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
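        """Write the next-token mask for this matcher into row `idx`."""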
        # This will automatically return the [EOS]-only mask if the matcher
        # is stopped or otherwise in an error state.
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        return self.terminated

    def reset(self):
        # TODO: this method may no longer be needed.
        self.ll_matcher.reset()


def serialize_guidance_grammar(request_type: StructuredOutputOptions,
                               grammar_spec: str,
                               disable_any_whitespace: bool = False) -> str:
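    """Translate a structured-output request into an llguidance grammar."""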
    if request_type == StructuredOutputOptions.JSON:
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    else:
        if request_type == StructuredOutputOptions.REGEX:
            tp = "regex"
        elif request_type == StructuredOutputOptions.GRAMMAR:
            tp = "grammar"
        elif request_type == StructuredOutputOptions.CHOICE:
            tp = "choice"
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError(
                f"Grammar is not of a valid supported type ({request_type!s}).")
        return llguidance.grammar_from(tp, grammar_spec)


def validate_guidance_grammar(
        sampling_params: SamplingParams,
        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
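    """Raise ValueError if the request's grammar fails llguidance validation."""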
    tp, grm = get_structured_output_key(sampling_params)
    guidance_grm = serialize_guidance_grammar(tp, grm)
    err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")