[Bugfix][Structured Output] Support outlines engine with reasoning outputs for DeepSeek R1 (#14114)
This commit is contained in: parent abcc61e0af, commit f5f7f00cd9
@@ -10,7 +10,9 @@ Reasoning models return a additional `reasoning_content` field in their outputs,
 
 vLLM currently supports the following reasoning models:
 
-- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) (`deepseek_r1`, which looks for `<think> ... </think>`)
+| Model Series | Parser Name | Structured Output Support |
+|--------------|-------------|------------------|
+| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` |
 
 ## Quickstart
 
@@ -78,11 +80,51 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
 
 Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests. You could checkout the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
 
+## Structured output
+
+The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output.
+
+```python
+from openai import OpenAI
+from pydantic import BaseModel
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+
+class People(BaseModel):
+    name: str
+    age: int
+
+
+json_schema = People.model_json_schema()
+
+prompt = ("Generate a JSON with the name and age of one random person.")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_json": json_schema},
+)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
+```
+
 ## Limitations
 
 - The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
 - It is not compatible with [`tool_calling`](#tool_calling).
 - The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
 
 ## How to support a new reasoning model
 
@@ -166,9 +208,10 @@ class DeepSeekReasoner(Reasoner):
 
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.end_token_id in input_ids
+    ...
 ```
 
-The structured output engine like xgrammar will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
+The structured output engine like `xgrammar` will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case.
 
 Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
 
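A minimal standalone sketch of the end-of-reasoning gate described in the hunk above (illustrative only, not part of this commit; the token id and function name are made up): the engine only starts enforcing the schema once the end-of-reasoning token has appeared.

```python
# Standalone sketch of the end-of-reasoning gate (illustrative names/ids only).
END_THINK_ID = 999  # made-up stand-in for the "</think>" token id


def should_apply_grammar(input_ids: list[int],
                         end_token_id: int = END_THINK_ID) -> bool:
    """Enforce structured-output constraints only after reasoning has ended."""
    return end_token_id in input_ids


print(should_apply_grammar([5, 6, 7]))                # False: still reasoning
print(should_apply_grammar([5, 6, 7, END_THINK_ID]))  # True: constrain from here on
```
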
@@ -33,6 +33,42 @@ client = OpenAI(
 models = client.models.list()
 model = models.data[0].id
 
+# Guided decoding by Regex
+prompt = ("What is the capital of France?")
+
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={
+        "guided_regex": "(Paris|London)",
+    },
+)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
+
+
+class People(BaseModel):
+    name: str
+    age: int
+
+
+json_schema = People.model_json_schema()
+
+prompt = ("Generate a JSON with the name and age of one random person.")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_json": json_schema},
+)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
+
+
 # Guided decoding by JSON using Pydantic schema
 class CarType(str, Enum):
 
@@ -51,7 +87,7 @@ class CarDescription(BaseModel):
 json_schema = CarDescription.model_json_schema()
 
 prompt = ("Generate a JSON with the brand, model and car_type of"
-          "the most iconic car from the 90's, think in 100 tokens")
+          "the most iconic car from the 90's")
 completion = client.chat.completions.create(
     model=model,
     messages=[{
 
@@ -60,5 +96,34 @@ completion = client.chat.completions.create(
     }],
     extra_body={"guided_json": json_schema},
 )
-print("content", completion.choices[0].message.content)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
+
+
+# Guided decoding by Grammar
+simplified_sql_grammar = """
+?start: select_statement
+
+?select_statement: "SELECT " column_list " FROM " table_name
+
+?column_list: column_name ("," column_name)*
+
+?table_name: identifier
+
+?column_name: identifier
+
+?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
+"""
+
+# This may be very slow https://github.com/vllm-project/vllm/issues/12122
+prompt = ("Generate an SQL query to show the 'username' and 'email'"
+          "from the 'users' table.")
+completion = client.chat.completions.create(
+    model=model,
+    messages=[{
+        "role": "user",
+        "content": prompt,
+    }],
+    extra_body={"guided_grammar": simplified_sql_grammar},
+)
+print("reasoning_content: ", completion.choices[0].message.reasoning_content)
+print("content: ", completion.choices[0].message.content)
 
@@ -112,6 +112,7 @@ async def get_guided_decoding_logits_processor(
     reasoner = get_reasoner(tokenizer, reasoning_backend)
 
     guided_params = maybe_backend_fallback(guided_params)
 
     # CFG grammar not supported by LMFE, so we use outlines instead
     if guided_params.backend_name == 'outlines':
         # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
 
@@ -43,7 +43,7 @@ class BaseLogitsProcessor:
 
     def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
         self._guide: Guide = guide
-        self._reasoner = reasoner
+        self._reasoner: Optional[Reasoner] = reasoner
         # CFGState is used for the FSM state for CFGGuide
         self._fsm_state: DefaultDict[int, Union[int,
                                                 CFGState]] = defaultdict(int)
 
@@ -54,10 +54,14 @@ class BaseLogitsProcessor:
 
         # Skip the structured logits processing if reasoning is not finished.
         # reasoner is not None only when `--enable-reasoning` is set.
-        if self._reasoner is not None and \
-            not self._reasoner.is_reasoning_end(
-                input_ids):
-            return scores
+        if self._reasoner is not None:
+            if not self._reasoner.is_reasoning_end(input_ids):
+                return scores
+            else:
+                # Remove the reasoning tokens from the input_ids
+                # We need this because our implementation relies on the
+                # hash of the input_ids to store the FSM state.
+                input_ids = self._reasoner.extract_content(input_ids)
 
         seq_id = hash(tuple(input_ids))
 
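As an aside on the change above: `BaseLogitsProcessor` keys its FSM state on `hash(tuple(input_ids))`, which is why the reasoning tokens are stripped with `extract_content` before the lookup. A small self-contained illustration (not part of this commit; token ids are invented and `content_only` is a simplified stand-in for `extract_content`):

```python
# Illustration only: the FSM-state key should cover content tokens, not
# reasoning tokens. All token ids are made up.
END_THINK = 999


def content_only(ids: list[int]) -> list[int]:
    # Simplified stand-in for extract_content(): tokens after the end marker.
    return ids[ids.index(END_THINK) + 1:]


# Two sequences that reasoned differently but carry the same structured content.
seq_a = [11, 12, 13, END_THINK, 7, 8]
seq_b = [21, 22, END_THINK, 7, 8]

print(hash(tuple(seq_a)) == hash(tuple(seq_b)))  # False: keys diverge
print(hash(tuple(content_only(seq_a))) == hash(tuple(content_only(seq_b))))  # True
```
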
@@ -4,10 +4,13 @@ from __future__ import annotations
 
 from transformers import PreTrainedTokenizer
 
+from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import ( # noqa: E501
     DeepSeekReasoner)
 from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
 
+logger = init_logger(__name__)
+
 
 def get_reasoner(tokenizer: PreTrainedTokenizer,
                  reasoning_backend: str | None) -> Reasoner | None:
 
@@ -17,7 +20,12 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
     elif reasoning_backend == "deepseek_r1":
         return DeepSeekReasoner.from_tokenizer(tokenizer)
     else:
-        raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
+        # Raise a warning for unknown reasoning backend and return None
+        # We cannot raise an error here because some reasoning models
+        # may not have a corresponding Reasoner class.
+        logger.warning("Unknown reasoning backend %s for structured outputs ",
+                       reasoning_backend)
+        return None
 
 
 __all__ = ["Reasoner", "get_reasoner"]
 
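The change above makes reasoner selection best-effort: an unknown backend now logs a warning and returns `None` instead of raising. A sketch of the caller-side consequence (illustrative only, not code from this commit): a `None` reasoner means the logits processors apply their constraints from the first generated token.

```python
# Illustrative only: what a None reasoner means downstream.
reasoner = None  # what get_reasoner() now returns for an unrecognised backend


def reasoning_finished(input_ids: list[int]) -> bool:
    # Mirrors the `self._reasoner is not None` gate in BaseLogitsProcessor above.
    return reasoner is None or reasoner.is_reasoning_end(input_ids)


print(reasoning_finished([1, 2, 3]))  # True: constraints apply immediately
```
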
@@ -26,3 +26,13 @@ class DeepSeekReasoner(Reasoner):
 
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         return self.end_token_id in input_ids
+
+    def extract_content(self, input_ids: list[int]) -> list[int]:
+        """
+        Extract the content after the end tokens
+        """
+        if self.end_token_id not in input_ids or \
+                input_ids.index(self.end_token_id) + 1 == len(input_ids):
+            return []
+        else:
+            return input_ids[input_ids.index(self.end_token_id) + 1:]
 
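The new `extract_content` has three cases worth spelling out: no end-of-reasoning token yet, the end token as the final id, and content after the end token. A quick worked example (illustrative only; the token id is made up and the helper restates the method's logic outside the class):

```python
END_THINK = 999  # made-up stand-in for DeepSeek R1's end-of-reasoning token id


def extract_content(input_ids: list[int]) -> list[int]:
    # Same logic as the method above, restated as a free function for the demo.
    if END_THINK not in input_ids or \
            input_ids.index(END_THINK) + 1 == len(input_ids):
        return []
    return input_ids[input_ids.index(END_THINK) + 1:]


print(extract_content([1, 2, 3]))                   # []     - still reasoning
print(extract_content([1, 2, 3, END_THINK]))        # []     - reasoning just ended
print(extract_content([1, 2, 3, END_THINK, 7, 8]))  # [7, 8] - content so far
```
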
@@ -17,3 +17,7 @@ class Reasoner(ABC):
     @abstractmethod
     def is_reasoning_end(self, input_ids: list[int]) -> bool:
         pass
+
+    @abstractmethod
+    def extract_content(self, input_ids: list[int]) -> list[int]:
+        pass
 
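With `extract_content` added to the abstract interface, a new reasoning parser has to provide both methods. A hedged sketch of what one might look like for a hypothetical model (`MyThinkReasoner`, the `</my_think>` marker, and the `from_tokenizer` wiring are illustrative, not part of this commit; the local `Reasoner` class just mirrors the two methods shown above):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class Reasoner(ABC):
    """Local mirror of the two methods the vLLM interface above requires."""

    @abstractmethod
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        ...

    @abstractmethod
    def extract_content(self, input_ids: list[int]) -> list[int]:
        ...


@dataclass
class MyThinkReasoner(Reasoner):
    """Hypothetical parser for a model that closes reasoning with '</my_think>'."""
    end_token_id: int

    @classmethod
    def from_tokenizer(cls, tokenizer) -> "MyThinkReasoner":
        # Assumes the end-of-reasoning marker is a single token in the vocab.
        return cls(end_token_id=tokenizer.convert_tokens_to_ids("</my_think>"))

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids

    def extract_content(self, input_ids: list[int]) -> list[int]:
        if self.end_token_id not in input_ids or \
                input_ids.index(self.end_token_id) + 1 == len(input_ids):
            return []
        return input_ids[input_ids.index(self.end_token_id) + 1:]
```

A real implementation would also need a branch in `get_reasoner` and a parser name accepted by `--reasoning-parser` before vLLM could select it.
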
@@ -392,7 +392,7 @@ class XGrammarLogitsProcessor:
     def clone(self) -> XGrammarLogitsProcessor:
        """Create a new instance with shared compiled grammar
          but separate state"""
-        new_processor = XGrammarLogitsProcessor(self.config)
+        new_processor = XGrammarLogitsProcessor(self.config, self.reasoner)
 
         # Share the compiled grammar context (immutable after compilation)
         new_processor.ctx = self.ctx