# SPDX-License-Identifier: Apache-2.0

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional

import torch

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import LazyLoader
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)
from vllm.v1.structured_output.request import get_structured_output_key

if TYPE_CHECKING:
    import llguidance
    import llguidance.hf as llguidance_hf
    import llguidance.torch as llguidance_torch
else:
    # Defer importing llguidance until first attribute access so the
    # dependency stays optional at module import time.
    llguidance = LazyLoader("llguidance", globals(), "llguidance")
    llguidance_hf = LazyLoader("llguidance.hf", globals(), "llguidance.hf")
    llguidance_torch = LazyLoader("llguidance.torch", globals(),
                                  "llguidance.torch")

logger = init_logger(__name__)


class GuidanceBackend(StructuredOutputBackend):
    """Structured-output backend built on top of the llguidance library."""

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        tokenizer_group = init_tokenizer_from_configs(
            model_config=vllm_config.model_config,
            scheduler_config=vllm_config.scheduler_config,
            parallel_config=vllm_config.parallel_config,
            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
        tokenizer_group.ping()
        self.vocab_size = vllm_config.model_config.get_vocab_size()
        self.disable_any_whitespace = (
            "disable-any-whitespace"
            in vllm_config.decoding_config.guided_decoding_backend)

        # Build an llguidance tokenizer from the base (non-LoRA) HF tokenizer.
        tokenizer = tokenizer_group.get_lora_tokenizer(None)
        self.ll_tokenizer = llguidance_hf.from_tokenizer(
            tokenizer, self.vocab_size)

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        # The serialized grammar is per-request state, so keep it local
        # rather than storing it on the shared backend instance.
        serialized_grammar = serialize_guidance_grammar(
            request_type, grammar_spec, self.disable_any_whitespace)

        ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            serialized_grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        r = GuidanceGrammar(
            ll_matcher=ll_matcher,
            ll_tokenizer=self.ll_tokenizer,
            vocab_size=self.vocab_size,
        )

        r.check_error()
        return r

    def allocate_token_bitmask(self, max_num_seqs: int):
        return llguidance_torch.allocate_token_bitmask(
            max_num_seqs, self.ll_tokenizer.vocab_size)
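
# A minimal usage sketch for the backend above (illustrative, not vLLM's
# actual call sites: `vllm_config` and `logits` are assumed to come from the
# surrounding engine, and `apply_token_bitmask_inplace` is assumed to be the
# llguidance.torch helper that masks disallowed tokens in-place):
#
#     backend = GuidanceBackend(vllm_config)
#     grammar = backend.compile_grammar(StructuredOutputOptions.JSON,
#                                       '{"type": "object"}')
#     bitmask = backend.allocate_token_bitmask(max_num_seqs=1)
#     grammar.fill_bitmask(bitmask, idx=0)
#     llguidance_torch.apply_token_bitmask_inplace(logits, bitmask)
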
@dataclass
class GuidanceGrammar(StructuredOutputGrammar):
    """Grammar state for a single request, wrapping an LLMatcher."""

    ll_matcher: llguidance.LLMatcher
    ll_tokenizer: llguidance.LLTokenizer
    vocab_size: int
    printed_error: bool = False
    terminated: bool = False

    def check_error(self):
        # Log each matcher error once; subsequent calls are no-ops.
        if not self.printed_error:
            err = self.ll_matcher.get_error()
            if err:
                self.printed_error = True
                logger.warning("LLMatcher error: %s", err)

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        """Accepts a list of tokens and advances the parser.

        Returns True if the parser was advanced successfully.
        Returns False if the parser failed to advance.
        """
        if self.ll_tokenizer.eos_token in tokens:
            self.terminated = True

        if self.ll_matcher.is_stopped():
            return True

        # TODO - Add jump decoding support in the future:
        # self.ll_matcher.compute_ff_bytes() - this should always work
        # self.ll_matcher.compute_ff_tokens() - this only works for
        # "canonical" tokenizers.
        # For conversion between the two, see
        # https://github.com/guidance-ai/llguidance/blob/main/docs/fast_forward.md

        r = self.ll_matcher.consume_tokens(tokens)
        self.check_error()
        return r

    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
        # This automatically produces an [EOS]-only mask if the matcher is
        # stopped or otherwise in an error state.
        llguidance_torch.fill_next_token_bitmask(self.ll_matcher, bitmask, idx)
        self.check_error()

    def is_terminated(self) -> bool:
        return self.terminated

    def reset(self):
        # TODO: this method may no longer be needed.
        self.ll_matcher.reset()
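
# Sketch of the per-step decode loop this grammar participates in (hedged:
# the real loop lives in vLLM's scheduler and GPU worker, and `sampled`
# stands in for whatever token ids the sampler produced):
#
#     while not grammar.is_terminated():
#         grammar.fill_bitmask(bitmask, idx=0)  # constrain next-token logits
#         # ... sample `sampled` from the masked logits ...
#         if not grammar.accept_tokens(request_id, sampled):
#             break  # matcher rejected the tokens; see check_error()
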

def serialize_guidance_grammar(request_type: StructuredOutputOptions,
                               grammar_spec: str,
                               disable_any_whitespace: bool = False) -> str:
    if request_type == StructuredOutputOptions.JSON:
        return llguidance.LLMatcher.grammar_from_json_schema(
            grammar_spec,
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    elif request_type == StructuredOutputOptions.JSON_OBJECT:
        return llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            defaults={
                "whitespace_flexible": not disable_any_whitespace,
            })
    else:
        if request_type == StructuredOutputOptions.REGEX:
            tp = "regex"
        elif request_type == StructuredOutputOptions.GRAMMAR:
            tp = "grammar"
        elif request_type == StructuredOutputOptions.CHOICE:
            tp = "choice"
        else:
            logger.error("Validation should have already occurred. "
                         "Please file an issue.")
            raise ValueError(
                f"Grammar is not of a supported type: {request_type!s}")
        return llguidance.grammar_from(tp, grammar_spec)
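
# For illustration, the request types map onto llguidance grammars roughly
# like this (the regex spec below is a made-up example, not a value taken
# from vLLM):
#
#     serialize_guidance_grammar(StructuredOutputOptions.JSON, schema_str)
#     serialize_guidance_grammar(StructuredOutputOptions.REGEX, r"\d{4}")
#     serialize_guidance_grammar(StructuredOutputOptions.CHOICE, choice_spec)
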

def validate_guidance_grammar(
        sampling_params: SamplingParams,
        tokenizer: Optional[llguidance.LLTokenizer] = None) -> None:
    tp, grm = get_structured_output_key(sampling_params)
    guidance_grm = serialize_guidance_grammar(tp, grm)
    err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
    if err:
        raise ValueError(f"Grammar error: {err}")