LM Format Enforcer Guided Decoding Support (#3868)
Co-authored-by: Simon Mo <simon.mo@hey.com>
parent 4e7ee664e2
commit 05434764cd
@@ -11,6 +11,7 @@ uvicorn[standard]
 pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 tiktoken == 0.6.0  # Required for DBRX tokenizer
+lm-format-enforcer == 0.9.3
 outlines == 0.0.34  # Requires torch >= 2.1.0
 typing_extensions
 filelock >= 3.10.4  # filelock starts to support `mode` argument from 3.10.4
@@ -1,11 +1,14 @@
 # This unit test should be moved to a new
 # tests/test_guided_decoding directory.
+import pytest
 import torch
 from transformers import AutoTokenizer
 
-from vllm.model_executor.guided_logits_processors import (JSONLogitsProcessor,
-                                                          RegexLogitsProcessor)
+from vllm.entrypoints.openai.protocol import CompletionRequest
+from vllm.model_executor.guided_decoding import (
+    get_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.outlines_logits_processors import (
+    JSONLogitsProcessor, RegexLogitsProcessor)
 
 TEST_SCHEMA = {
     "type": "object",
@@ -73,3 +76,36 @@ def test_guided_logits_processors():
     json_LP(token_ids, tensor)
     assert tensor.shape == original_tensor.shape
     assert not torch.allclose(tensor, original_tensor)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"])
+async def test_guided_logits_processor_black_box(backend: str):
+    tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta')
+    token_ids = tokenizer.encode(
+        f"Give an example IPv4 address with this regex: {TEST_REGEX}")
+    regex_request = CompletionRequest(model='test',
+                                      prompt=token_ids,
+                                      guided_regex=TEST_REGEX)
+    regex_lp = await get_guided_decoding_logits_processor(
+        backend, regex_request, tokenizer)
+    assert regex_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = regex_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
+
+    token_ids = tokenizer.encode(
+        f"Give an employee profile that fits this schema: {TEST_SCHEMA}")
+    json_request = CompletionRequest(model='test',
+                                     prompt=token_ids,
+                                     guided_json=TEST_SCHEMA)
+    json_lp = await get_guided_decoding_logits_processor(
+        backend, json_request, tokenizer)
+    assert json_lp is not None
+    tensor = torch.rand(32000)
+    original_tensor = torch.clone(tensor)
+    tensor = json_lp(token_ids, tensor)
+    assert tensor.shape == original_tensor.shape
+    assert not torch.allclose(tensor, original_tensor)
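Aside (not part of the diff): the pre-existing test_guided_logits_processors above builds the outlines processors directly, while the new black-box test goes through the backend-agnostic entry point. A minimal sketch of the direct route; the constructor signature is inferred from the RegexFSM(regex_string, tokenizer) call in the hunks further down, so treat it as an assumption.

import torch
from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    RegexLogitsProcessor)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

# Assumed signature: RegexLogitsProcessor(regex_string, tokenizer).
regex_lp = RegexLogitsProcessor(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", tokenizer)

token_ids = tokenizer.encode("Give an example IPv4 address: ")
scores = torch.rand(32000)            # fake logits for one decoding step
scores = regex_lp(token_ids, scores)  # tokens outside the regex get a -inf bias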
@@ -506,7 +506,10 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI):
     assert first_response != completion.choices[0].text
 
 
-async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
+                                      guided_decoding_backend: str):
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt=f"Give an example JSON for an employee profile "
@@ -514,7 +517,8 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
         n=3,
         temperature=1.0,
         max_tokens=500,
-        extra_body=dict(guided_json=TEST_SCHEMA))
+        extra_body=dict(guided_json=TEST_SCHEMA,
+                        guided_decoding_backend=guided_decoding_backend))
 
     assert completion.id is not None
     assert completion.choices is not None and len(completion.choices) == 3
@@ -524,7 +528,10 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI):
         jsonschema.validate(instance=output_json, schema=TEST_SCHEMA)
 
 
-async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
+                                guided_decoding_backend: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -538,8 +545,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
     chat_completion = await client.chat.completions.create(
         model=MODEL_NAME,
         messages=messages,
-        max_tokens=500,
-        extra_body=dict(guided_json=TEST_SCHEMA))
+        max_tokens=1000,
+        extra_body=dict(guided_json=TEST_SCHEMA,
+                        guided_decoding_backend=guided_decoding_backend))
     message = chat_completion.choices[0].message
     assert message.content is not None
     json1 = json.loads(message.content)
@@ -555,8 +563,9 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
     chat_completion = await client.chat.completions.create(
         model=MODEL_NAME,
         messages=messages,
-        max_tokens=500,
-        extra_body=dict(guided_json=TEST_SCHEMA))
+        max_tokens=1000,
+        extra_body=dict(guided_json=TEST_SCHEMA,
+                        guided_decoding_backend=guided_decoding_backend))
     message = chat_completion.choices[0].message
     assert message.content is not None
     json2 = json.loads(message.content)
@@ -565,14 +574,18 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI):
     assert json1["age"] != json2["age"]
 
 
-async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
+                                       guided_decoding_backend: str):
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
         n=3,
         temperature=1.0,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX))
+        extra_body=dict(guided_regex=TEST_REGEX,
+                        guided_decoding_backend=guided_decoding_backend))
 
     assert completion.id is not None
     assert completion.choices is not None and len(completion.choices) == 3
@@ -581,7 +594,10 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI):
         assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None
 
 
-async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
+                                 guided_decoding_backend: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -595,7 +611,8 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
         model=MODEL_NAME,
         messages=messages,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX))
+        extra_body=dict(guided_regex=TEST_REGEX,
+                        guided_decoding_backend=guided_decoding_backend))
     ip1 = chat_completion.choices[0].message.content
     assert ip1 is not None
     assert re.fullmatch(TEST_REGEX, ip1) is not None
@@ -606,21 +623,26 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI):
         model=MODEL_NAME,
         messages=messages,
         max_tokens=20,
-        extra_body=dict(guided_regex=TEST_REGEX))
+        extra_body=dict(guided_regex=TEST_REGEX,
+                        guided_decoding_backend=guided_decoding_backend))
     ip2 = chat_completion.choices[0].message.content
     assert ip2 is not None
     assert re.fullmatch(TEST_REGEX, ip2) is not None
     assert ip1 != ip2
 
 
-async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
+                                        guided_decoding_backend: str):
     completion = await client.completions.create(
         model=MODEL_NAME,
         prompt="The best language for type-safe systems programming is ",
         n=2,
         temperature=1.0,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE))
+        extra_body=dict(guided_choice=TEST_CHOICE,
+                        guided_decoding_backend=guided_decoding_backend))
 
     assert completion.id is not None
     assert completion.choices is not None and len(completion.choices) == 2
@@ -628,7 +650,10 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI):
         assert completion.choices[i].text in TEST_CHOICE
 
 
-async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
+                                  guided_decoding_backend: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -642,7 +667,8 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
         model=MODEL_NAME,
         messages=messages,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE))
+        extra_body=dict(guided_choice=TEST_CHOICE,
+                        guided_decoding_backend=guided_decoding_backend))
     choice1 = chat_completion.choices[0].message.content
     assert choice1 in TEST_CHOICE
 
@@ -655,18 +681,23 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI):
         model=MODEL_NAME,
         messages=messages,
         max_tokens=10,
-        extra_body=dict(guided_choice=TEST_CHOICE))
+        extra_body=dict(guided_choice=TEST_CHOICE,
+                        guided_decoding_backend=guided_decoding_backend))
     choice2 = chat_completion.choices[0].message.content
     assert choice2 in TEST_CHOICE
     assert choice1 != choice2
 
 
-async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI):
+@pytest.mark.parametrize("guided_decoding_backend",
+                         ["outlines", "lm-format-enforcer"])
+async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
+                                          guided_decoding_backend: str):
     with pytest.raises(openai.BadRequestError):
         _ = await client.completions.create(
             model=MODEL_NAME,
             prompt="Give an example JSON that fits this schema: 42",
-            extra_body=dict(guided_json=42))
+            extra_body=dict(guided_json=42,
+                            guided_decoding_backend=guided_decoding_backend))
 
     messages = [{
         "role": "system",
@@ -1079,6 +1079,21 @@ def _get_and_verify_max_len(
     return int(max_model_len)
 
 
+@dataclass
+class DecodingConfig:
+    """Dataclass which contains the decoding strategy of the engine"""
+
+    # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer'
+    guided_decoding_backend: str = 'outlines'
+
+    def __post_init__(self):
+        valid_guided_backends = ['outlines', 'lm-format-enforcer']
+        backend = self.guided_decoding_backend
+        if backend not in valid_guided_backends:
+            raise ValueError(f"Invalid guided_decoding_backend '{backend},"
+                             f"must be one of {valid_guided_backends}")
+
+
 @dataclass(frozen=True)
 class EngineConfig:
     """Dataclass which contains all engine-related configuration. This
@@ -1093,6 +1108,7 @@ class EngineConfig:
     lora_config: Optional[LoRAConfig]
     vision_language_config: Optional[VisionLanguageConfig]
     speculative_config: Optional[SpeculativeConfig]
+    decoding_config: Optional[DecodingConfig]
     tensorizer_config: Optional[TensorizerConfig]
 
     def __post_init__(self):
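Illustration (not part of the diff): the new dataclass validates the backend eagerly, so a typo fails at engine construction rather than at request time. A small sketch, assuming DecodingConfig is importable from vllm.config as the later hunks do:

from vllm.config import DecodingConfig

DecodingConfig()                                              # default: 'outlines'
DecodingConfig(guided_decoding_backend="lm-format-enforcer")  # accepted

try:
    DecodingConfig(guided_decoding_backend="not-a-backend")
except ValueError as err:
    print(err)  # error message names the valid backends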
@@ -5,9 +5,9 @@ import os
 from dataclasses import dataclass
 from typing import BinaryIO, Optional, Union
 
-from vllm.config import (CacheConfig, DeviceConfig, EngineConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig,
-                         SpeculativeConfig, TensorizerConfig,
+from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
+                         EngineConfig, LoRAConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, SpeculativeConfig, TensorizerConfig,
                          TokenizerPoolConfig, VisionLanguageConfig)
 from vllm.model_executor.tensorizer_loader import TensorizerArgs
 from vllm.utils import str_to_int_tuple
@@ -80,6 +80,7 @@ class EngineArgs:
     scheduler_delay_factor: float = 0.0
     enable_chunked_prefill: bool = False
 
+    guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
     num_speculative_tokens: Optional[int] = None
@@ -200,6 +201,13 @@ class EngineArgs:
             default=EngineArgs.max_model_len,
             help='model context length. If unspecified, '
             'will be automatically derived from the model.')
+        parser.add_argument(
+            '--guided-decoding-backend',
+            type=str,
+            default='outlines',
+            choices=['outlines', 'lm-format-enforcer'],
+            help='Which engine will be used for guided decoding'
+            ' (JSON schema / regex etc)')
         # Parallel arguments
         parser.add_argument('--worker-use-ray',
                             action='store_true',
@@ -511,6 +519,9 @@ class EngineArgs:
         else:
             vision_language_config = None
 
+        decoding_config = DecodingConfig(
+            guided_decoding_backend=self.guided_decoding_backend)
+
         return EngineConfig(model_config=model_config,
                             cache_config=cache_config,
                             parallel_config=parallel_config,
@@ -519,6 +530,7 @@ class EngineArgs:
                             lora_config=lora_config,
                             vision_language_config=vision_language_config,
                             speculative_config=speculative_config,
+                            decoding_config=decoding_config,
                             tensorizer_config=tensorizer_config)
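Usage note (not part of the diff): the new EngineArgs field is the programmatic twin of the --guided-decoding-backend CLI flag added above; both feed the DecodingConfig that gets attached to the engine config. A sketch, with the model name borrowed from the tests:

from vllm.config import DecodingConfig
from vllm.engine.arg_utils import EngineArgs

# Equivalent to passing `--guided-decoding-backend lm-format-enforcer`
# when launching the OpenAI-compatible server.
engine_args = EngineArgs(model="HuggingFaceH4/zephyr-7b-beta",
                         guided_decoding_backend="lm-format-enforcer")

# Mirrors what the engine-config hunk above does with this value.
decoding_config = DecodingConfig(
    guided_decoding_backend=engine_args.guided_decoding_backend)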
@@ -4,9 +4,10 @@ from typing import Iterable, List, Optional, Tuple, Type, Union
 from transformers import PreTrainedTokenizer
 
 import vllm
-from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
-                         ParallelConfig, SchedulerConfig, SpeculativeConfig,
-                         TensorizerConfig, VisionLanguageConfig)
+from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoRAConfig,
+                         ModelConfig, ParallelConfig, SchedulerConfig,
+                         SpeculativeConfig, TensorizerConfig,
+                         VisionLanguageConfig)
 from vllm.core.scheduler import Scheduler, SchedulerOutputs
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics import StatLogger, Stats
@@ -74,6 +75,7 @@ class LLMEngine:
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         speculative_config: Optional[SpeculativeConfig],
+        decoding_config: Optional[DecodingConfig],
         tensorizer_config: Optional[TensorizerConfig],
         executor_class: Type[ExecutorBase],
         log_stats: bool,
@@ -100,6 +102,7 @@ class LLMEngine:
             f"kv_cache_dtype={cache_config.cache_dtype}, "
             f"quantization_param_path={model_config.quantization_param_path}, "
             f"device_config={device_config.device}, "
+            f"decoding_config={decoding_config!r}, "
             f"seed={model_config.seed})")
         # TODO(woosuk): Print more configs in debug mode.
 
@@ -111,6 +114,7 @@ class LLMEngine:
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.speculative_config = speculative_config
+        self.decoding_config = decoding_config or DecodingConfig()
         self.tensorizer_config = tensorizer_config
         self.log_stats = log_stats
@@ -133,6 +133,12 @@ class ChatCompletionRequest(BaseModel):
         description=(
             "If specified, the output will follow the context free grammar."),
     )
+    guided_decoding_backend: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be either "
+            "'outlines' / 'lm-format-enforcer'"))
 
     # doc: end-chat-completion-extra-params
 
@@ -265,6 +271,12 @@ class CompletionRequest(BaseModel):
         description=(
             "If specified, the output will follow the context free grammar."),
     )
+    guided_decoding_backend: Optional[str] = Field(
+        default=None,
+        description=(
+            "If specified, will override the default guided decoding backend "
+            "of the server for this specific request. If set, must be one of "
+            "'outlines' / 'lm-format-enforcer'"))
 
     # doc: end-completion-extra-params
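For illustration (not part of the diff): the new request field lets a client override the server-wide backend per request through extra_body, exactly as the updated tests do. A sketch using the openai Python client, assuming a vLLM OpenAI-compatible server is already running locally with a dummy API key; the regex is only an example:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.completions.create(
    model="HuggingFaceH4/zephyr-7b-beta",
    prompt="Give an example IPv4 address: ",
    max_tokens=20,
    # guided_regex constrains the output; guided_decoding_backend overrides the
    # server default ('outlines' or 'lm-format-enforcer') for this request only.
    extra_body=dict(guided_regex=r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}",
                    guided_decoding_backend="lm-format-enforcer"),
)
print(completion.choices[0].text)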
@@ -68,9 +68,13 @@ class OpenAIServingChat(OpenAIServing):
                 request, prompt=prompt)
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            decoding_config = self.engine.engine.decoding_config
+            guided_decoding_backend = request.guided_decoding_backend \
+                or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, await self.engine.get_tokenizer()))
+                    guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
@@ -88,9 +88,13 @@ class OpenAIServingCompletion(OpenAIServing):
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            decoding_config = self.engine.engine.decoding_config
+            guided_decoding_backend = request.guided_decoding_backend \
+                or decoding_config.guided_decoding_backend
             guided_decode_logit_processor = (
                 await get_guided_decoding_logits_processor(
-                    request, await self.engine.get_tokenizer()))
+                    guided_decoding_backend, request, await
+                    self.engine.get_tokenizer()))
             if guided_decode_logit_processor is not None:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
vllm/model_executor/guided_decoding/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+from typing import Optional, Union
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              CompletionRequest)
+from vllm.model_executor.guided_decoding.lm_format_enforcer_decoding import (
+    get_lm_format_enforcer_guided_decoding_logits_processor)
+from vllm.model_executor.guided_decoding.outlines_decoding import (
+    get_outlines_guided_decoding_logits_processor)
+from vllm.sampling_params import LogitsProcessor
+
+
+async def get_guided_decoding_logits_processor(
+        guided_decoding_backend: str, request: Union[CompletionRequest,
+                                                     ChatCompletionRequest],
+        tokenizer) -> Optional[LogitsProcessor]:
+    if guided_decoding_backend == 'outlines':
+        return await get_outlines_guided_decoding_logits_processor(
+            request, tokenizer)
+    if guided_decoding_backend == 'lm-format-enforcer':
+        return await get_lm_format_enforcer_guided_decoding_logits_processor(
+            request, tokenizer)
+
+    raise ValueError(
+        f"Unknown guided decoding backend '{guided_decoding_backend}'. "
+        "Must be one of 'outlines, 'lm-format-enforcer'")
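A sketch (not part of the diff) of driving the new dispatch function outside of pytest; it mirrors the black-box test added earlier, and the backend string selects between the two implementations:

import asyncio

import torch
from transformers import AutoTokenizer

from vllm.entrypoints.openai.protocol import CompletionRequest
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)


async def main():
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
    prompt_ids = tokenizer.encode("Answer with exactly one word, cat or dog: ")
    request = CompletionRequest(model="test",
                                prompt=prompt_ids,
                                guided_choice=["cat", "dog"])
    # "outlines" or "lm-format-enforcer"; anything else raises ValueError.
    logits_processor = await get_guided_decoding_logits_processor(
        "lm-format-enforcer", request, tokenizer)
    scores = torch.rand(32000)                     # fake logits, as in the test
    scores = logits_processor(prompt_ids, scores)  # biases disallowed tokens
    print(scores.shape)


asyncio.run(main())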
@@ -0,0 +1,69 @@
+from functools import lru_cache
+from json import loads as json_loads
+from typing import Optional, Union
+
+from lmformatenforcer import (CharacterLevelParser, JsonSchemaParser,
+                              RegexParser, StringParser,
+                              TokenEnforcerTokenizerData, UnionParser)
+from lmformatenforcer.integrations.vllm import (
+    build_vllm_logits_processor, build_vllm_token_enforcer_tokenizer_data)
+from pydantic import BaseModel
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              CompletionRequest)
+from vllm.model_executor.guided_decoding.outlines_decoding import (
+    get_outlines_guided_decoding_logits_processor)
+from vllm.sampling_params import LogitsProcessor
+
+
+async def get_lm_format_enforcer_guided_decoding_logits_processor(
+        request: Union[CompletionRequest, ChatCompletionRequest],
+        tokenizer) -> Optional[LogitsProcessor]:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    We cache logit processors by (guide, tokenizer), and on cache hit
+    we make a shallow copy to reuse the same underlying FSM.
+    """
+
+    tokenizer_data = _cached_build_vllm_token_enforcer_tokenizer_data(
+        tokenizer)
+    character_level_parser: CharacterLevelParser
+    if request.guided_json:
+        schema = _normalize_json_schema_object(request.guided_json)
+        character_level_parser = JsonSchemaParser(schema)
+    elif request.guided_choice:
+        character_level_parser = UnionParser(
+            [StringParser(choice) for choice in request.guided_choice])
+    elif request.guided_regex:
+        character_level_parser = RegexParser(request.guided_regex)
+    elif request.guided_grammar:
+        # CFG grammar not supported by LMFE, revert to outlines
+        return await get_outlines_guided_decoding_logits_processor(
+            request, tokenizer)
+    elif (request.response_format is not None
+          and request.response_format.type == "json_object"):
+        character_level_parser = JsonSchemaParser(
+            None)  # None means any json object
+    else:
+        return None
+
+    logits_processor = build_vllm_logits_processor(tokenizer_data,
+                                                   character_level_parser)
+    return logits_processor
+
+
+def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
+    if isinstance(schema, str):
+        return json_loads(schema)
+    if isinstance(schema, dict):
+        return schema
+    if isinstance(schema, BaseModel):
+        return schema.model_json_schema()
+
+
+@lru_cache
+def _cached_build_vllm_token_enforcer_tokenizer_data(
+        tokenizer: PreTrainedTokenizerBase) -> TokenEnforcerTokenizerData:
+    return build_vllm_token_enforcer_tokenizer_data(tokenizer)
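Side note (not part of the diff): guided_json can arrive as a raw JSON string, a dict, or a pydantic model, and _normalize_json_schema_object reduces all three to a plain dict before it reaches lm-format-enforcer. A small sketch of that normalization, assuming lm-format-enforcer and pydantic v2 are installed; the Employee model is only an example:

from lmformatenforcer import JsonSchemaParser
from pydantic import BaseModel


class Employee(BaseModel):
    name: str
    age: int


# Dict form of the schema, as _normalize_json_schema_object produces for the
# BaseModel case via model_json_schema().
schema = Employee.model_json_schema()
parser = JsonSchemaParser(schema)  # character-level parser used by the backend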
@@ -12,9 +12,8 @@ from transformers import PreTrainedTokenizerBase
 
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest)
-from vllm.model_executor.guided_logits_processors import (CFGLogitsProcessor,
-                                                          JSONLogitsProcessor,
-                                                          RegexLogitsProcessor)
+from vllm.model_executor.guided_decoding.outlines_logits_processors import (
+    CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
 
 
 class GuidedDecodingMode(Enum):
@@ -54,7 +53,7 @@ pair : UNESCAPED_STRING ":" value
 global_thread_pool = None  # used for generating logits processor fsm
 
 
-async def get_guided_decoding_logits_processor(
+async def get_outlines_guided_decoding_logits_processor(
         request: Union[CompletionRequest, ChatCompletionRequest],
         tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
     """
@@ -13,9 +13,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import json
 import math
 from collections import defaultdict
+from functools import lru_cache
 from typing import Callable, DefaultDict, Dict, List, Optional, Union
 
 import torch
@@ -27,50 +29,6 @@ from transformers import PreTrainedTokenizerBase
 
 class BaseLogitsProcessor:
 
-    def adapt_tokenizer(self, tokenizer: PreTrainedTokenizerBase):
-        """Adapt vLLM's tokenizer to use to compile the FSM.
-
-        The API of Outlines tokenizers is slightly different to that of
-        `transformers`. The decoder of outlines, returns a list whereas
-        the decode of vLLM returns an str. To sync the vLLM decoder with
-        outlines internal api, the decoder should be adapted. In addition
-        we need to handle the missing spaces to Llama's tokenizer to be
-        able to compile FSMs for this model.
-
-        """
-        if getattr(tokenizer, "_outlines_adapted", False):
-            return tokenizer
-
-        tokenizer.vocabulary = tokenizer.get_vocab()
-        tokenizer.special_tokens = set(tokenizer.all_special_tokens)
-
-        def convert_token_to_string(token: str) -> str:
-            from transformers.file_utils import SPIECE_UNDERLINE
-
-            string = tokenizer.convert_tokens_to_string([token])
-
-            # A hack to handle missing spaces to HF's Llama tokenizers
-            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
-                return " " + string
-
-            return string
-
-        def change_decoder(
-            decoder: Callable[[List[int]], str]
-        ) -> Callable[[List[int]], List[str]]:
-            """Sync vLLM's decoder with the outlines by returning list."""
-
-            def new_decoder(inp_tokens: List[int]) -> List[str]:
-                return [decoder(inp_tokens)]
-
-            return new_decoder
-
-        tokenizer.convert_token_to_string = convert_token_to_string
-        tokenizer.decode = change_decoder(tokenizer.decode)
-        setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
-
-        return tokenizer
-
     def init_state(self):
         """Initialize the FSM states."""
         self.fsm_state: DefaultDict[int, int] = defaultdict(int)
@@ -78,7 +36,6 @@ class BaseLogitsProcessor:
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
         """Use the FSM to bias the logits before sampling the next token."""
-
         seq_id = hash(tuple(input_ids))
 
         if len(input_ids) == 0:
@@ -96,7 +53,6 @@ class BaseLogitsProcessor:
                            device=scores.device)
         mask[allowed_tokens] = 0
         scores.add_(mask)
-
         return scores
 
 
@@ -113,7 +69,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer
 
         """
-        tokenizer = self.adapt_tokenizer(tokenizer)
+        tokenizer = _adapt_tokenizer(tokenizer)
         fsm = RegexFSM(regex_string, tokenizer)
         self.fsm = fsm
 
@@ -167,6 +123,54 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
             The model's tokenizer
 
         """
-        tokenizer = self.adapt_tokenizer(tokenizer)
+        tokenizer = _adapt_tokenizer(tokenizer)
         fsm = CFGFSM(cfg, tokenizer)
         self.fsm = fsm
+
+
+@lru_cache
+def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
+    """Adapt vLLM's tokenizer to use to compile the FSM.
+
+    The API of Outlines tokenizers is slightly different to that of
+    `transformers`. The decoder of outlines, returns a list whereas
+    the decode of vLLM returns an str. To sync the vLLM decoder with
+    outlines internal api, the decoder should be adapted. In addition
+    we need to handle the missing spaces to Llama's tokenizer to be
+    able to compile FSMs for this model.
+
+    """
+    if getattr(tokenizer, "_outlines_adapted", False):
+        return tokenizer
+
+    tokenizer = copy.deepcopy(tokenizer)
+
+    tokenizer.vocabulary = tokenizer.get_vocab()
+    tokenizer.special_tokens = set(tokenizer.all_special_tokens)
+
+    def convert_token_to_string(token: str) -> str:
+        from transformers.file_utils import SPIECE_UNDERLINE
+
+        string = tokenizer.convert_tokens_to_string([token])
+
+        # A hack to handle missing spaces to HF's Llama tokenizers
+        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
+            return " " + string
+
+        return string
+
+    def change_decoder(
+        decoder: Callable[[List[int]],
+                          str]) -> Callable[[List[int]], List[str]]:
+        """Sync vLLM's decoder with the outlines by returning list."""
+
+        def new_decoder(inp_tokens: List[int]) -> List[str]:
+            return [decoder(inp_tokens)]
+
+        return new_decoder
+
+    tokenizer.convert_token_to_string = convert_token_to_string
+    tokenizer.decode = change_decoder(tokenizer.decode)
+    setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010
+
+    return tokenizer
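Design note (not part of the diff): adapt_tokenizer used to mutate the tokenizer it was given every time a processor was built; the new module-level _adapt_tokenizer deep-copies it and memoizes the adapted copy with @lru_cache, so the engine's own tokenizer keeps its normal decode while the outlines-flavoured copy is built only once per tokenizer. A sketch of the resulting behaviour (the private helper name is taken from the hunk above):

from transformers import AutoTokenizer

from vllm.model_executor.guided_decoding.outlines_logits_processors import (
    _adapt_tokenizer)

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

adapted_a = _adapt_tokenizer(tokenizer)
adapted_b = _adapt_tokenizer(tokenizer)

assert adapted_a is adapted_b                    # lru_cache: one adapted copy
assert adapted_a is not tokenizer                # original is never mutated
assert isinstance(tokenizer.decode([1]), str)    # vLLM-style decode kept
assert isinstance(adapted_a.decode([1]), list)   # outlines-style list decode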