Support arbitrary json_object in OpenAI and Context Free Grammar (#3211)
This commit is contained in:
parent
8e67598aa6
commit
120157fd2a
@ -660,5 +660,55 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
|
||||
extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
|
||||
|
||||
|
||||
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
|
||||
resp = await client.chat.completions.create(
|
||||
model=MODEL_NAME,
|
||||
messages=[{
|
||||
"role":
|
||||
"user",
|
||||
"content": ('what is 1+1? please respond with a JSON object, '
|
||||
'the format is {"result": 2}')
|
||||
}],
|
||||
response_format={"type": "json_object"})
|
||||
|
||||
content = resp.choices[0].message.content
|
||||
loaded = json.loads(content)
|
||||
assert loaded == {"result": 2}, loaded
|
||||
|
||||
|
||||
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
|
||||
simple_sql_grammar = """
|
||||
start: select_statement
|
||||
|
||||
select_statement: "SELECT" column "from" table "where" condition
|
||||
|
||||
column: "col_1" | "col_2"
|
||||
table: "table_1" | "table_2"
|
||||
condition: column "=" number
|
||||
|
||||
number: "1" | "2"
|
||||
"""
|
||||
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=("Generate a sql state that select col_1 from "
|
||||
"table_1 where it is equals to 1"),
|
||||
temperature=1.0,
|
||||
max_tokens=500,
|
||||
extra_body=dict(guided_grammar=simple_sql_grammar))
|
||||
|
||||
content = completion.choices[0].text
|
||||
|
||||
# use Lark to parse the output, and make sure it's a valid parse tree
|
||||
from lark import Lark
|
||||
parser = Lark(simple_sql_grammar)
|
||||
parser.parse(content)
|
||||
|
||||
# remove spaces for comparison b/c we removed them in the grammar
|
||||
ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
|
||||
|
||||
assert content.strip() == ground_truth
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
@ -55,6 +55,11 @@ class UsageInfo(BaseModel):
|
||||
completion_tokens: Optional[int] = 0
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
# type must be "json_object" or "text"
|
||||
type: str = Literal["text", "json_object"]
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
@ -89,6 +94,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = None
|
||||
guided_regex: Optional[str] = None
|
||||
guided_choice: Optional[List[str]] = None
|
||||
guided_grammar: Optional[str] = None
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
def to_sampling_params(self) -> SamplingParams:
|
||||
if self.logprobs and not self.top_logprobs:
|
||||
@ -183,6 +190,8 @@ class CompletionRequest(BaseModel):
|
||||
guided_json: Optional[Union[str, dict, BaseModel]] = None
|
||||
guided_regex: Optional[str] = None
|
||||
guided_choice: Optional[List[str]] = None
|
||||
guided_grammar: Optional[str] = None
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
|
||||
def to_sampling_params(self):
|
||||
echo_without_generation = self.echo and self.max_tokens == 0
|
||||
|
@ -6,20 +6,51 @@ from functools import lru_cache
|
||||
from json import dumps as json_dumps
|
||||
from re import escape as regex_escape
|
||||
from typing import Union, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||
ChatCompletionRequest)
|
||||
from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
|
||||
RegexLogitsProcessor)
|
||||
RegexLogitsProcessor,
|
||||
CFGLogitsProcessor)
|
||||
|
||||
|
||||
class GuidedDecodingMode(Enum):
|
||||
JSON = "json"
|
||||
REGEX = "regex"
|
||||
CHOICE = "choice"
|
||||
GRAMMAR = "grammar"
|
||||
|
||||
|
||||
# https://github.com/outlines-dev/outlines/blob/main/outlines/grammars/json.lark
|
||||
# the main difference is that we changed the start: value to
|
||||
# start: object | array, so we are denying scalar values as the root of the
|
||||
# JSON. Starting with scalars as the root seems to cause llama to generate
|
||||
# without stop.
|
||||
JSON_GRAMMAR = r"""
|
||||
?start: object | array
|
||||
|
||||
?value: object
|
||||
| array
|
||||
| UNESCAPED_STRING
|
||||
| SIGNED_NUMBER -> number
|
||||
| "true" -> true
|
||||
| "false" -> false
|
||||
| "null" -> null
|
||||
|
||||
array : "[" [value ("," value)*] "]"
|
||||
object : "{" [pair ("," pair)*] "}"
|
||||
pair : UNESCAPED_STRING ":" value
|
||||
|
||||
%import common.UNESCAPED_STRING
|
||||
%import common.SIGNED_NUMBER
|
||||
%import common.WS
|
||||
|
||||
%ignore WS
|
||||
"""
|
||||
|
||||
global_thread_pool = None # used for generating logits processor fsm
|
||||
|
||||
|
||||
@ -57,9 +88,6 @@ def _get_guide_and_mode(
|
||||
) -> Tuple[str, GuidedDecodingMode]:
|
||||
|
||||
if request.guided_json:
|
||||
if not isinstance(request.guided_json, (str, dict, BaseModel)):
|
||||
raise TypeError("JSON schema must be str, dict, or BaseModel")
|
||||
|
||||
json = request.guided_json
|
||||
if isinstance(json, dict):
|
||||
# turn dict into hashable string
|
||||
@ -69,33 +97,33 @@ def _get_guide_and_mode(
|
||||
# with the same fields will get hashed the same
|
||||
json = str(json.__signature__)
|
||||
return json, GuidedDecodingMode.JSON
|
||||
|
||||
elif request.guided_regex:
|
||||
if not isinstance(request.guided_regex, str):
|
||||
raise TypeError("Regex must be string")
|
||||
return request.guided_regex, GuidedDecodingMode.REGEX
|
||||
|
||||
elif request.guided_choice:
|
||||
if not isinstance(request.guided_choice, list):
|
||||
raise TypeError("Choices must be a list")
|
||||
|
||||
# choice just uses regex
|
||||
choices = [
|
||||
regex_escape(str(choice)) for choice in request.guided_choice
|
||||
]
|
||||
choices_regex = "(" + "|".join(choices) + ")"
|
||||
return choices_regex, GuidedDecodingMode.CHOICE
|
||||
|
||||
elif request.guided_grammar:
|
||||
return request.guided_grammar, GuidedDecodingMode.GRAMMAR
|
||||
elif (request.response_format is not None
|
||||
and request.response_format.type == "json_object"):
|
||||
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_cached_logits_processor(guide: str, tokenizer,
|
||||
def _get_cached_logits_processor(guide: str,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
mode: GuidedDecodingMode):
|
||||
if mode == GuidedDecodingMode.JSON:
|
||||
return JSONLogitsProcessor(guide, tokenizer)
|
||||
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
||||
return RegexLogitsProcessor(guide, tokenizer)
|
||||
elif mode == GuidedDecodingMode.GRAMMAR:
|
||||
return CFGLogitsProcessor(guide, tokenizer)
|
||||
else:
|
||||
raise ValueError(f"Unknown guided decoding mode {mode}")
|
||||
|
@ -16,30 +16,60 @@
|
||||
import json
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from typing import Union, DefaultDict, Dict, List, Optional
|
||||
from typing import Union, DefaultDict, Dict, List, Optional, Callable
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from outlines.fsm.fsm import RegexFSM
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
from outlines.fsm.fsm import RegexFSM, CFGFSM
|
||||
from outlines.fsm.json_schema import build_regex_from_schema
|
||||
|
||||
|
||||
class RegexLogitsProcessor:
|
||||
class BaseLogitsProcessor:
|
||||
|
||||
def __init__(self, regex_string: str, tokenizer):
|
||||
"""Compile the FSM that drives the regex-structured generation.
|
||||
def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
|
||||
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
regex_string
|
||||
A string that represents a regular expression
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
The API of Outlines tokenizers is slightly different to that of
|
||||
`transformers`. The decoder of outlines, returns a list whereas
|
||||
the decode of vLLM returns an str. To sync the vLLM decoder with
|
||||
outlines internal api, the decoder should be adapted. In addition
|
||||
we need to handle the missing spaces to Llama's tokenizer to be
|
||||
able to compile FSMs for this model.
|
||||
|
||||
"""
|
||||
tokenizer = self.adapt_tokenizer(tokenizer)
|
||||
fsm = RegexFSM(regex_string, tokenizer)
|
||||
self.fsm = fsm
|
||||
if getattr(tokenizer, "_outlines_adapted", False):
|
||||
return tokenizer
|
||||
|
||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||
|
||||
def convert_token_to_string(token: str) -> str:
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||
return " " + string
|
||||
|
||||
return string
|
||||
|
||||
def change_decoder(
|
||||
decoder: Callable[[List[int]], str]
|
||||
) -> Callable[[List[int]], List[str]]:
|
||||
"""Sync vLLM's decoder with the outlines by returning list."""
|
||||
|
||||
def new_decoder(inp_tokens: List[int]) -> List[str]:
|
||||
return [decoder(inp_tokens)]
|
||||
|
||||
return new_decoder
|
||||
|
||||
tokenizer.convert_token_to_string = convert_token_to_string
|
||||
tokenizer.decode = change_decoder(tokenizer.decode)
|
||||
setattr(tokenizer, "_outlines_adapted", True) # noqa: B010
|
||||
|
||||
return tokenizer
|
||||
|
||||
def init_state(self):
|
||||
"""Initialize the FSM states."""
|
||||
@ -69,38 +99,30 @@ class RegexLogitsProcessor:
|
||||
|
||||
return scores
|
||||
|
||||
def adapt_tokenizer(self, tokenizer):
|
||||
"""Adapt vLLM's tokenizer to use to compile the FSM.
|
||||
|
||||
The API of Outlines tokenizers is slightly different to that of
|
||||
`transformers`. In addition we need to handle the missing spaces to
|
||||
Llama's tokenizer to be able to compile FSMs for this model.
|
||||
class RegexLogitsProcessor(BaseLogitsProcessor):
|
||||
|
||||
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
|
||||
"""Compile the FSM that drives the regex-structured generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
regex_string
|
||||
A string that represents a regular expression
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
|
||||
"""
|
||||
tokenizer.vocabulary = tokenizer.get_vocab()
|
||||
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
|
||||
|
||||
def convert_token_to_string(token: str) -> str:
|
||||
from transformers.file_utils import SPIECE_UNDERLINE
|
||||
|
||||
string = tokenizer.convert_tokens_to_string([token])
|
||||
|
||||
# A hack to handle missing spaces to HF's Llama tokenizers
|
||||
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
|
||||
return " " + string
|
||||
|
||||
return string
|
||||
|
||||
tokenizer.convert_token_to_string = convert_token_to_string
|
||||
|
||||
return tokenizer
|
||||
tokenizer = self.adapt_tokenizer(tokenizer)
|
||||
fsm = RegexFSM(regex_string, tokenizer)
|
||||
self.fsm = fsm
|
||||
|
||||
|
||||
class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||
|
||||
def __init__(self,
|
||||
schema: Union[str, Dict, BaseModel],
|
||||
tokenizer,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
whitespace_pattern: Optional[str] = None):
|
||||
"""Compile the FSM that drives the JSON-guided generation.
|
||||
|
||||
@ -130,3 +152,21 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
|
||||
f"the JSON Schema specification")
|
||||
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
|
||||
super().__init__(regex_string, tokenizer)
|
||||
|
||||
|
||||
class CFGLogitsProcessor(BaseLogitsProcessor):
|
||||
|
||||
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
|
||||
"""Compile the FSM that drives the context free grammar generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cfg
|
||||
A string that represents a context-free grammar
|
||||
tokenizer
|
||||
The model's tokenizer
|
||||
|
||||
"""
|
||||
tokenizer = self.adapt_tokenizer(tokenizer)
|
||||
fsm = CFGFSM(cfg, tokenizer)
|
||||
self.fsm = fsm
|
||||
|
Loading…
x
Reference in New Issue
Block a user