LM Format Enforcer Guided Decoding Support (#3868)

Co-authored-by: Simon Mo <simon.mo@hey.com>

commit 05434764cd (parent 4e7ee664e2)
@@ -11,6 +11,7 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.9.3
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
@@ -1,11 +1,14 @@
# This unit test should be moved to a new
# tests/test_guided_decoding directory.

import pytest
import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
                                                          RegexLogitsProcessor)
from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    JSONLogitsProcessor, RegexLogitsProcessor)

TEST_SCHEMA = {
    "type": "object",
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
    json_LP(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)


@pytest.mark.asyncio
@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
async def test_guided_logits_processor_black_box(backend: str):
    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    token_ids = tokenizer.encode(
        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
    regex_request = CompletionRequest(model='test',
                                      prompt=token_ids,
                                      guided_regex=TEST_REGEX)
    regex_lp = await get_guided_decoding_logits_processor(
        backend, regex_request, tokenizer)
    assert regex_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = regex_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)

    token_ids = tokenizer.encode(
        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
    json_request = CompletionRequest(model='test',
                                     prompt=token_ids,
                                     guided_json=TEST_SCHEMA)
    json_lp = await get_guided_decoding_logits_processor(
        backend, json_request, tokenizer)
    assert json_lp is not None
    tensor = torch.rand(32000)
    original_tensor = torch.clone(tensor)
    tensor = json_lp(token_ids, tensor)
    assert tensor.shape == original_tensor.shape
    assert not torch.allclose(tensor, original_tensor)
@@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
    assert first_response != completion.choices[0].text


async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
                                      guided_decoding_backend: str):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example JSON for an employee profile "
@@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
        n=3,
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_json=TEST_SCHEMA))
        extra_body=dict(guided_json=TEST_SCHEMA,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 3
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
        jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)


async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
                                guided_decoding_backend: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=500,
        extra_body=dict(guided_json=TEST_SCHEMA))
        max_tokens=1000,
        extra_body=dict(guided_json=TEST_SCHEMA,
                        guided_decoding_backend=guided_decoding_backend))
    message = chat_completion.choices[0].message
    assert message.content is not None
    json1 = json.loads(message.content)
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
    chat_completion = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=500,
        extra_body=dict(guided_json=TEST_SCHEMA))
        max_tokens=1000,
        extra_body=dict(guided_json=TEST_SCHEMA,
                        guided_decoding_backend=guided_decoding_backend))
    message = chat_completion.choices[0].message
    assert message.content is not None
    json2 = json.loads(message.content)
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
    assert json1["age"] != json2["age"]


async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
        n=3,
        temperature=1.0,
        max_tokens=20,
        extra_body=dict(guided_regex=TEST_REGEX))
        extra_body=dict(guided_regex=TEST_REGEX,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 3
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
        assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None


async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
                                 guided_decoding_backend: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
        model=MODEL_NAME,
        messages=messages,
        max_tokens=20,
        extra_body=dict(guided_regex=TEST_REGEX))
        extra_body=dict(guided_regex=TEST_REGEX,
                        guided_decoding_backend=guided_decoding_backend))
    ip1 = chat_completion.choices[0].message.content
    assert ip1 is not None
    assert re.fullmatch(TEST_REGEX, ip1) is not None
@@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
        model=MODEL_NAME,
        messages=messages,
        max_tokens=20,
        extra_body=dict(guided_regex=TEST_REGEX))
        extra_body=dict(guided_regex=TEST_REGEX,
                        guided_decoding_backend=guided_decoding_backend))
    ip2 = chat_completion.choices[0].message.content
    assert ip2 is not None
    assert re.fullmatch(TEST_REGEX, ip2) is not None
    assert ip1 != ip2


async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="The best language for type-safe systems programming is ",
        n=2,
        temperature=1.0,
        max_tokens=10,
        extra_body=dict(guided_choice=TEST_CHOICE))
        extra_body=dict(guided_choice=TEST_CHOICE,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 2
@@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
        assert completion.choices[i].text in TEST_CHOICE


async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
                                  guided_decoding_backend: str):
    messages = [{
        "role": "system",
        "content": "you are a helpful assistant"
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
        extra_body=dict(guided_choice=TEST_CHOICE))
        extra_body=dict(guided_choice=TEST_CHOICE,
                        guided_decoding_backend=guided_decoding_backend))
    choice1 = chat_completion.choices[0].message.content
    assert choice1 in TEST_CHOICE

@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
        model=MODEL_NAME,
        messages=messages,
        max_tokens=10,
        extra_body=dict(guided_choice=TEST_CHOICE))
        extra_body=dict(guided_choice=TEST_CHOICE,
                        guided_decoding_backend=guided_decoding_backend))
    choice2 = chat_completion.choices[0].message.content
    assert choice2 in TEST_CHOICE
    assert choice1 != choice2


async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
                                          guided_decoding_backend: str):
    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example JSON that fits this schema: 42",
            extra_body=dict(guided_json=42))
            extra_body=dict(guided_json=42,
                            guided_decoding_backend=guided_decoding_backend))

    messages = [{
        "role": "system",
@@ -66,8 +66,8 @@ class ModelConfig:
            weights. If None, we assume the model weights are not quantized.
        quantization_param_path: Path to JSON file containing scaling factors.
            Used to load KV cache scaling factors into the model when KV cache
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also
            be used to load activation and weight scaling factors when the
            model dtype is FP8_E4M3 on ROCm.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
@@ -422,7 +422,7 @@ class CacheConfig:
@dataclass
class TokenizerPoolConfig:
    """Configuration for the tokenizer pool.

    Args:
        pool_size: Number of tokenizer workers in the pool.
        pool_type: Type of the pool.
@@ -446,9 +446,9 @@ class TokenizerPoolConfig:
            tokenizer_pool_extra_config: Optional[Union[str, dict]]
    ) -> Optional["TokenizerPoolConfig"]:
        """Create a TokenizerPoolConfig from the given parameters.

        If tokenizer_pool_size is 0, return None.

        Args:
            tokenizer_pool_size: Number of tokenizer workers in the pool.
            tokenizer_pool_type: Type of the pool.
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len(
    return int(max_model_len)


@dataclass
class DecodingConfig:
    """Dataclass which contains the decoding strategy of the engine"""

    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
    guided_decoding_backend: str = 'outlines'

    def __post_init__(self):
        valid_guided_backends = ['outlines', 'lm-format-enforcer']
        backend = self.guided_decoding_backend
        if backend not in valid_guided_backends:
            raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                             f"must be one of {valid_guided_backends}")

@dataclass(frozen=True)
class EngineConfig:
    """Dataclass which contains all engine-related configuration. This
@@ -1093,6 +1108,7 @@ class EngineConfig:
    lora_config: Optional[LoRAConfig]
    vision_language_config: Optional[VisionLanguageConfig]
    speculative_config: Optional[SpeculativeConfig]
    decoding_config: Optional[DecodingConfig]
    tensorizer_config: Optional[TensorizerConfig]

    def __post_init__(self):
@@ -5,9 +5,9 @@ import os
from dataclasses import dataclass
from typing import BinaryIO, Optional, Union

from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, TensorizerConfig,
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                         EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, SpeculativeConfig, TensorizerConfig,
                         TokenizerPoolConfig, VisionLanguageConfig)
from vllm.model_executor.tensorizer_loader import TensorizerArgs
from vllm.utils import str_to_int_tuple
@@ -80,6 +80,7 @@ class EngineArgs:
    scheduler_delay_factor: float = 0.0
    enable_chunked_prefill: bool = False

    guided_decoding_backend: str = 'outlines'
    # Speculative decoding configuration.
    speculative_model: Optional[str] = None
    num_speculative_tokens: Optional[int] = None
@@ -200,6 +201,13 @@ class EngineArgs:
            default=EngineArgs.max_model_len,
            help='model context length. If unspecified, '
            'will be automatically derived from the model.')
        parser.add_argument(
            '--guided-decoding-backend',
            type=str,
            default='outlines',
            choices=['outlines', 'lm-format-enforcer'],
            help='Which engine will be used for guided decoding'
            ' (JSON schema / regex etc)')
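For illustration, a hedged sketch (not part of the diff) of how the new flag maps onto `EngineArgs` when the engine is configured programmatically; the model name is only a placeholder, and the server CLI exposes the same setting as `--guided-decoding-backend`.

```python
from vllm.config import DecodingConfig
from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing --guided-decoding-backend lm-format-enforcer on the CLI.
args = EngineArgs(model='HuggingFaceH4/zephyr-7b-beta',
                  guided_decoding_backend='lm-format-enforcer')

# Mirrors what the engine-config creation hunk below does with this field.
decoding_config = DecodingConfig(
    guided_decoding_backend=args.guided_decoding_backend)
```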
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
@@ -511,6 +519,9 @@ class EngineArgs:
        else:
            vision_language_config = None

        decoding_config = DecodingConfig(
            guided_decoding_backend=self.guided_decoding_backend)

        return EngineConfig(model_config=model_config,
                            cache_config=cache_config,
                            parallel_config=parallel_config,
@@ -519,6 +530,7 @@ class EngineArgs:
                            lora_config=lora_config,
                            vision_language_config=vision_language_config,
                            speculative_config=speculative_config,
                            decoding_config=decoding_config,
                            tensorizer_config=tensorizer_config)
@@ -4,9 +4,10 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
from transformers import PreTrainedTokenizer

import vllm
from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
                         TensorizerConfig, VisionLanguageConfig)
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         SpeculativeConfig, TensorizerConfig,
                         VisionLanguageConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +75,7 @@ class LLMEngine:
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
        speculative_config: Optional[SpeculativeConfig],
        decoding_config: Optional[DecodingConfig],
        tensorizer_config: Optional[TensorizerConfig],
        executor_class: Type[ExecutorBase],
        log_stats: bool,
@@ -100,6 +102,7 @@ class LLMEngine:
            f"kv_cache_dtype={cache_config.cache_dtype}, "
            f"quantization_param_path={model_config.quantization_param_path}, "
            f"device_config={device_config.device}, "
            f"decoding_config={decoding_config!r}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.

@@ -111,6 +114,7 @@ class LLMEngine:
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.speculative_config = speculative_config
        self.decoding_config = decoding_config or DecodingConfig()
        self.tensorizer_config = tensorizer_config
        self.log_stats = log_stats
@@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))

    # doc: end-chat-completion-extra-params

@@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))

    # doc: end-completion-extra-params
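As a hedged client-side sketch (not from the diff), the two request fields above let a caller override the server default per request; the OpenAI client forwards the extra vLLM fields through `extra_body`. The endpoint URL and model name below are placeholders.

```python
import openai

client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

async def constrained_completion():
    # guided_json constrains the output to the schema; guided_decoding_backend
    # overrides the server-wide default for this request only.
    return await client.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        prompt="Give an example JSON for an employee profile: ",
        max_tokens=100,
        extra_body=dict(guided_json={"type": "object"},
                        guided_decoding_backend="lm-format-enforcer"))
```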
@@ -68,9 +68,13 @@ class OpenAIServingChat(OpenAIServing):
            request, prompt=prompt)
        sampling_params = request.to_sampling_params()
        lora_request = self._maybe_get_lora(request)
        decoding_config = self.engine.engine.decoding_config
        guided_decoding_backend = request.guided_decoding_backend \
            or decoding_config.guided_decoding_backend
        guided_decode_logits_processor = (
            await get_guided_decoding_logits_processor(
                request, await self.engine.get_tokenizer()))
                guided_decoding_backend, request, await
                self.engine.get_tokenizer()))
        if guided_decode_logits_processor:
            if sampling_params.logits_processors is None:
                sampling_params.logits_processors = []
@@ -88,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
        try:
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            decoding_config = self.engine.engine.decoding_config
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logit_processor = (
                await get_guided_decoding_logits_processor(
                    request, await self.engine.get_tokenizer()))
                    guided_decoding_backend, request, await
                    self.engine.get_tokenizer()))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
vllm/model_executor/guided_decoding/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from typing import Optional, Union

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest)
from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
    get_lm_format_enforcer_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.outlines_decoding import (
    get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor


async def get_guided_decoding_logits_processor(
        guided_decoding_backend: str, request: Union[CompletionRequest,
                                                     ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
    if guided_decoding_backend == 'outlines':
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    if guided_decoding_backend == 'lm-format-enforcer':
        return await get_lm_format_enforcer_guided_decoding_logits_processor(
            request, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
        "Must be one of 'outlines', 'lm-format-enforcer'")
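A hedged usage sketch (not part of the diff) of the dispatcher above, mirroring how the OpenAI serving layer attaches the returned processor to `SamplingParams.logits_processors`; the tokenizer and request values are placeholders.

```python
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
from vllm.sampling_params import SamplingParams

async def build_guided_sampling_params() -> SamplingParams:
    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
    request = CompletionRequest(model='test',
                                prompt='An IPv4 address: ',
                                guided_regex=r'\d+\.\d+\.\d+\.\d+')
    # Returns an outlines or lm-format-enforcer processor, or None when the
    # request carries no guided decoding parameters.
    processor = await get_guided_decoding_logits_processor(
        'lm-format-enforcer', request, tokenizer)
    params = SamplingParams(max_tokens=20)
    if processor is not None:
        params.logits_processors = [processor]
    return params
```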
@@ -0,0 +1,69 @@
from functools import lru_cache
from json import loads as json_loads
from typing import Optional, Union

from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
                              RegexParser, StringParser,
                              TokenEnforcerTokenizerData, UnionParser)
from lmformatenforcer.integrations.vllm import (
    build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
    get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor


async def get_lm_format_enforcer_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Optional[LogitsProcessor]:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    The lm-format-enforcer tokenizer data is cached per tokenizer and reused
    across requests; the logits processor itself is built per request.
    """

    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer)
    character_level_parser: CharacterLevelParser
    if request.guided_json:
        schema = _normalize_json_schema_object(request.guided_json)
        character_level_parser = JsonSchemaParser(schema)
    elif request.guided_choice:
        character_level_parser = UnionParser(
            [StringParser(choice) for choice in request.guided_choice])
    elif request.guided_regex:
        character_level_parser = RegexParser(request.guided_regex)
    elif request.guided_grammar:
        # CFG grammar is not supported by LMFE, fall back to outlines
        return await get_outlines_guided_decoding_logits_processor(
            request, tokenizer)
    elif (request.response_format is not None
          and request.response_format.type == "json_object"):
        character_level_parser = JsonSchemaParser(
            None)  # None means any json object
    else:
        return None

    logits_processor = build_vllm_logits_processor(tokenizer_data,
                                                   character_level_parser)
    return logits_processor


def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
    if isinstance(schema, str):
        return json_loads(schema)
    if isinstance(schema, dict):
        return schema
    if isinstance(schema, BaseModel):
        return schema.model_json_schema()


@lru_cache
def _cached_build_vllm_token_enforcer_tokenizer_data(
        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
    return build_vllm_token_enforcer_tokenizer_data(tokenizer)
@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest)
from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor,
                                                          JSONLogitsProcessor,
                                                          RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)


class GuidedDecodingMode(Enum):
@@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value
global_thread_pool = None  # used for generating logits processor fsm


async def get_guided_decoding_logits_processor(
async def get_outlines_guided_decoding_logits_processor(
        request: Union[CompletionRequest, ChatCompletionRequest],
        tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
    """
@@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Optional, Union

import torch
@@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase

class BaseLogitsProcessor:

    def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
        """Adapt vLLM's tokenizer to use to compile the FSM.

        The API of Outlines tokenizers is slightly different to that of
        `transformers`. The decoder of outlines returns a list whereas
        the decode of vLLM returns a str. To sync the vLLM decoder with
        outlines' internal API, the decoder should be adapted. In addition
        we need to handle the missing spaces in Llama's tokenizer to be
        able to compile FSMs for this model.

        """
        if getattr(tokenizer, "_outlines_adapted", False):
            return tokenizer

        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle missing spaces in HF's Llama tokenizers
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        def change_decoder(
            decoder: Callable[[List[int]], str]
        ) -> Callable[[List[int]], List[str]]:
            """Sync vLLM's decoder with outlines by returning a list."""

            def new_decoder(inp_tokens: List[int]) -> List[str]:
                return [decoder(inp_tokens)]

            return new_decoder

        tokenizer.convert_token_to_string = convert_token_to_string
        tokenizer.decode = change_decoder(tokenizer.decode)
        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010

        return tokenizer

    def init_state(self):
        """Initialize the FSM states."""
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
@@ -78,7 +36,6 @@ class BaseLogitsProcessor:
    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""

        seq_id = hash(tuple(input_ids))

        if len(input_ids) == 0:
@@ -96,7 +53,6 @@ class BaseLogitsProcessor:
                           device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)

        return scores


@@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
                The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        tokenizer = _adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm

@@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
                The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        tokenizer = _adapt_tokenizer(tokenizer)
        fsm = CFGFSM(cfg, tokenizer)
        self.fsm = fsm


@lru_cache
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
    """Adapt vLLM's tokenizer to use to compile the FSM.

    The API of Outlines tokenizers is slightly different to that of
    `transformers`. The decoder of outlines returns a list whereas
    the decode of vLLM returns a str. To sync the vLLM decoder with
    outlines' internal API, the decoder should be adapted. In addition
    we need to handle the missing spaces in Llama's tokenizer to be
    able to compile FSMs for this model.

    """
    if getattr(tokenizer, "_outlines_adapted", False):
        return tokenizer

    tokenizer = copy.deepcopy(tokenizer)

    tokenizer.vocabulary = tokenizer.get_vocab()
    tokenizer.special_tokens = set(tokenizer.all_special_tokens)

    def convert_token_to_string(token: str) -> str:
        from transformers.file_utils import SPIECE_UNDERLINE

        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces in HF's Llama tokenizers
        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
            return " " + string

        return string

    def change_decoder(
        decoder: Callable[[List[int]],
                          str]) -> Callable[[List[int]], List[str]]:
        """Sync vLLM's decoder with outlines by returning a list."""

        def new_decoder(inp_tokens: List[int]) -> List[str]:
            return [decoder(inp_tokens)]

        return new_decoder

    tokenizer.convert_token_to_string = convert_token_to_string
    tokenizer.decode = change_decoder(tokenizer.decode)
    setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010

    return tokenizer
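As a small illustrative check (not part of the diff) of the contract the docstring above describes: the adapted tokenizer exposes outlines-style attributes and its `decode` returns a list instead of a string. `_adapt_tokenizer` is module-private, so this is only a sketch of the behavior, not a recommended API; the model name is a placeholder.

```python
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    _adapt_tokenizer)

tok = _adapt_tokenizer(
    AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta'))
assert isinstance(tok.vocabulary, dict)         # vocab under the name outlines expects
assert isinstance(tok.decode([1, 2, 3]), list)  # decode now returns a list of strings
```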