vllm/vllm/entrypoints/openai/protocol.py
Noam Gat 05434764cd
LM Format Enforcer Guided Decoding Support (#3868)
Co-authored-by: Simon Mo <simon.mo@hey.com>
2024-04-16 05:54:57 +00:00

438 lines
16 KiB
Python

# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
import torch
from pydantic import BaseModel, Field, conint, model_validator
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
class ErrorResponse(BaseModel):
    """Error payload (``object == "error"``) with a message, type and code."""

    object: str = "error"
    message: str
    type: str
    # Optional name of the offending request parameter.
    param: Union[str, None] = None
    code: int
class ModelPermission(BaseModel):
    """Permission entry listed in :attr:`ModelCard.permission`.

    The defaults allow sampling, logprobs and viewing, and disallow engine
    creation, search indices and fine-tuning.
    """

    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    # Creation time as integer Unix seconds, taken when the object is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    # A boolean flag like the allow_* fields above. It must be annotated as
    # bool: pydantic does not validate defaults, so a `str` annotation with a
    # False default made the schema and the serialized value disagree.
    is_blocking: bool = False
class ModelCard(BaseModel):
    """Metadata for one model (``object == "model"``) plus its permissions."""

    id: str
    object: str = "model"
    # Integer Unix seconds at construction time.
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Union[str, None] = None
    parent: Union[str, None] = None
    permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel):
    """List envelope (``object == "list"``) holding :class:`ModelCard` items."""

    object: str = Field(default="list")
    data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel):
    """Token counters for a request; every counter defaults to zero."""

    prompt_tokens: int = Field(default=0)
    total_tokens: int = Field(default=0)
    # Declared optional, unlike the other two counters.
    completion_tokens: Union[int, None] = Field(default=0)
class ResponseFormat(BaseModel):
    """The ``response_format`` object of a chat or completion request."""

    # type must be "json_object" or "text". Annotating with the Literal makes
    # pydantic reject anything else. The original code assigned the Literal
    # *as the default value*, which disabled validation and produced a
    # non-string when `type` was omitted. Plain text is the natural default.
    type: Literal["text", "json_object"] = "text"
class ChatCompletionRequest(BaseModel):
    """Request body for an OpenAI-compatible chat completion.

    Fields up to ``user`` follow the official OpenAI API; the remaining
    fields are vLLM sampling and guided-decoding extensions.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/chat/create
    messages: List[Dict[str, str]]
    model: str
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    response_format: Optional[ResponseFormat] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-chat-completion-sampling-params
    best_of: Optional[int] = None
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    # doc: end-chat-completion-sampling-params

    # doc: begin-chat-completion-extra-params
    echo: Optional[bool] = Field(
        default=False,
        description=(
            "If true, the new message will be prepended with the last message "
            "if they belong to the same role."),
    )
    add_generation_prompt: Optional[bool] = Field(
        default=True,
        description=
        ("If true, the generation prompt will be added to the chat template. "
         "This is a parameter used by chat template in tokenizer config of the "
         "model."),
    )
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be either "
            "'outlines' / 'lm-format-enforcer'"))
    # doc: end-chat-completion-extra-params

    def to_sampling_params(self) -> SamplingParams:
        """Translate this request into vLLM ``SamplingParams``.

        ``top_logprobs`` is forwarded as ``logprobs`` when ``logprobs`` is
        true, and as ``prompt_logprobs`` when ``echo`` is true.

        Raises:
            ValueError: if ``logprobs`` is true but ``top_logprobs`` is unset
                or 0, or if a ``logit_bias`` key is not an integer token id.
        """
        if self.logprobs and not self.top_logprobs:
            raise ValueError("Top logprobs must be set when logprobs is.")

        logits_processors = None
        if self.logit_bias:
            # Parse and clamp the bias table once, here, instead of on every
            # decoding step inside the processor. A bad key then fails at
            # request time rather than in the middle of generation.
            biases = []
            for token_id, bias in self.logit_bias.items():
                try:
                    token = int(token_id)
                except ValueError as err:
                    raise ValueError(
                        "logit_bias keys must be integer token ids, "
                        f"got {token_id!r}") from err
                # Clamp the bias between -100 and 100 per OpenAI API spec
                biases.append((token, min(100, max(-100, bias))))

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                # NOTE(review): assumes `logits` is the 1-D vocabulary row
                # for a single sequence; confirm against the engine's
                # logits-processor contract.
                for token, bias in biases:
                    logits[token] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
            min_tokens=self.min_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
            top_k=self.top_k,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Reject raw input that sets more than one guided-decoding option.

        Only ``guided_json``, ``guided_regex`` and ``guided_choice`` are
        counted; a key that is present but ``None`` does not count.
        NOTE(review): ``guided_grammar`` is not counted, so combining it with
        another guide passes validation. Confirm this is intended.
        """
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
class CompletionRequest(BaseModel):
    """Request body for an OpenAI-compatible (legacy) text completion.

    Fields up to ``user`` follow the official OpenAI API; the remaining
    fields are vLLM sampling and guided-decoding extensions.
    """

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/completions/create
    model: str
    prompt: Union[List[int], List[List[int]], str, List[str]]
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = 16
    n: Optional[int] = 1
    presence_penalty: Optional[float] = 0.0
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    suffix: Optional[str] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    user: Optional[str] = None

    # doc: begin-completion-sampling-params
    use_beam_search: Optional[bool] = False
    top_k: Optional[int] = -1
    min_p: Optional[float] = 0.0
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    ignore_eos: Optional[bool] = False
    min_tokens: Optional[int] = 0
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    truncate_prompt_tokens: Optional[conint(ge=1)] = None
    # doc: end-completion-sampling-params

    # doc: begin-completion-extra-params
    include_stop_str_in_output: Optional[bool] = Field(
        default=False,
        description=(
            "Whether to include the stop string in the output. "
            "This is only applied when the stop or stop_token_ids is set."),
    )
    response_format: Optional[ResponseFormat] = Field(
        default=None,
        description=
        ("Similar to chat completion, this parameter specifies the format of "
         "output. Only {'type': 'json_object'} or {'type': 'text' } is "
         "supported."),
    )
    guided_json: Optional[Union[str, dict, BaseModel]] = Field(
        default=None,
        description=("If specified, the output will follow the JSON schema."),
    )
    guided_regex: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the regex pattern."),
    )
    guided_choice: Optional[List[str]] = Field(
        default=None,
        description=(
            "If specified, the output will be exactly one of the choices."),
    )
    guided_grammar: Optional[str] = Field(
        default=None,
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
            "If specified, will override the default guided decoding backend "
            "of the server for this specific request. If set, must be one of "
            "'outlines' / 'lm-format-enforcer'"))
    # doc: end-completion-extra-params

    def to_sampling_params(self) -> SamplingParams:
        """Translate this request into vLLM ``SamplingParams``.

        ``logprobs`` is forwarded as-is, and also as ``prompt_logprobs``
        when ``echo`` is set. If ``echo`` is set and ``max_tokens == 0``
        (echo without generation), ``max_tokens`` is forwarded as 1.
        Presumably the engine must generate at least one token, and the
        serving code discards it; TODO confirm.

        Raises:
            ValueError: if a ``logit_bias`` key is not an integer token id.
        """
        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = None
        if self.logit_bias:
            # Parse and clamp the bias table once, here, instead of on every
            # decoding step inside the processor. A bad key then fails at
            # request time rather than in the middle of generation.
            biases = []
            for token_id, bias in self.logit_bias.items():
                try:
                    token = int(token_id)
                except ValueError as err:
                    raise ValueError(
                        "logit_bias keys must be integer token ids, "
                        f"got {token_id!r}") from err
                # Clamp the bias between -100 and 100 per OpenAI API spec
                biases.append((token, min(100, max(-100, bias))))

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                # NOTE(review): assumes `logits` is the 1-D vocabulary row
                # for a single sequence; confirm against the engine's
                # logits-processor contract.
                for token, bias in biases:
                    logits[token] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            min_tokens=self.min_tokens,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
            truncate_prompt_tokens=self.truncate_prompt_tokens,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        """Reject raw input that sets more than one guided-decoding option.

        Only ``guided_json``, ``guided_regex`` and ``guided_choice`` are
        counted; a key that is present but ``None`` does not count.
        NOTE(review): ``guided_grammar`` is not counted, so combining it with
        another guide passes validation. Confirm this is intended.
        """
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
class LogProbs(BaseModel):
    """Log-probability details attached to a completion or chat choice."""

    text_offset: List[int] = Field(default_factory=list)
    # Entries may be None.
    token_logprobs: List[Union[float, None]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    # Optional per-position mapping of candidate token -> logprob.
    top_logprobs: Union[List[Union[Dict[str, float], None]], None] = None
class CompletionResponseChoice(BaseModel):
    """A single generated choice inside a :class:`CompletionResponse`."""

    index: int
    text: str
    logprobs: Union[LogProbs, None] = None
    finish_reason: Union[Literal["stop", "length"], None] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
class CompletionResponse(BaseModel):
    """Complete (non-streamed) ``text_completion`` response with usage."""

    # Identifier of the form "cmpl-<random id>".
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Integer Unix seconds at construction time.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
    """A single choice inside a :class:`CompletionStreamResponse` chunk."""

    index: int
    text: str
    logprobs: Union[LogProbs, None] = None
    finish_reason: Union[Literal["stop", "length"], None] = None
    stop_reason: Union[None, int, str] = Field(
        default=None,
        description=(
            "The stop string or token id that caused the completion "
            "to stop, None if the completion finished for some other reason "
            "including encountering the EOS token"),
    )
class CompletionStreamResponse(BaseModel):
    """One streamed ``text_completion`` chunk; ``usage`` is optional here."""

    # Identifier of the form "cmpl-<random id>".
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    # Integer Unix seconds at construction time.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Union[UsageInfo, None] = None
class ChatMessage(BaseModel):
    """A complete chat message: a role together with its text content."""

    role: str
    content: str
class ChatCompletionResponseChoice(BaseModel):
    """A single choice inside a :class:`ChatCompletionResponse`."""

    index: int
    message: ChatMessage
    logprobs: Union[LogProbs, None] = None
    finish_reason: Union[Literal["stop", "length"], None] = None
    # Stop string or stop token id that ended generation, if any.
    stop_reason: Union[None, int, str] = Field(default=None)
class ChatCompletionResponse(BaseModel):
    """Complete (non-streamed) ``chat.completion`` response with usage."""

    # Identifier of the form "chatcmpl-<random id>".
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    # Integer Unix seconds at construction time.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo
class DeltaMessage(BaseModel):
    """Partial chat message in a stream chunk; either part may be absent."""

    role: Union[str, None] = None
    content: Union[str, None] = None
class ChatCompletionResponseStreamChoice(BaseModel):
    """A single choice inside a :class:`ChatCompletionStreamResponse` chunk."""

    index: int
    delta: DeltaMessage
    logprobs: Union[LogProbs, None] = None
    finish_reason: Union[Literal["stop", "length"], None] = None
    # Stop string or stop token id that ended generation, if any.
    stop_reason: Union[None, int, str] = Field(default=None)
class ChatCompletionStreamResponse(BaseModel):
    """One streamed ``chat.completion.chunk``; ``usage`` is optional here."""

    # Identifier of the form "chatcmpl-<random id>".
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    # Integer Unix seconds at construction time.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Union[UsageInfo, None] = None