[Bugfix][Structured Output] Support outlines engine with reasoning outputs for DeepSeek R1 (#14114)

This commit is contained in:
Ce Gao 2025-03-06 11:49:20 +08:00 committed by GitHub
parent abcc61e0af
commit f5f7f00cd9
8 changed files with 147 additions and 12 deletions

View File

@ -10,7 +10,9 @@ Reasoning models return an additional `reasoning_content` field in their outputs,
vLLM currently supports the following reasoning models:
- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) (`deepseek_r1`, which looks for `<think> ... </think>`)
| Model Series | Parser Name | Structured Output Support |
|--------------|-------------|------------------|
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` |
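
Conceptually, a reasoning parser such as `deepseek_r1` splits the completion into the reasoning block and the final answer. A minimal plain-Python sketch of that idea (illustrative only; the `split_reasoning` helper is made up here, and vLLM's actual parser operates on token ids and handles streaming):

```python
import re


def split_reasoning(text: str) -> tuple[str | None, str]:
    """Split '<think>...</think>answer' into (reasoning_content, content)."""
    match = re.match(r"<think>(.*?)</think>(.*)", text, re.DOTALL)
    if match is None:
        return None, text
    return match.group(1).strip(), match.group(2).strip()


print(split_reasoning("<think>2 + 2 = 4</think>The answer is 4."))
# ('2 + 2 = 4', 'The answer is 4.')
```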
## Quickstart
@ -78,11 +80,51 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
Please note that it is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests instead; see the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
## Structured output
The reasoning content is also available when structured output is requested. A structured output engine such as `xgrammar` will account for the reasoning content: the schema is enforced only on the final answer, while the reasoning tokens are passed through unconstrained.
```python
from openai import OpenAI
from pydantic import BaseModel

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id


class People(BaseModel):
    name: str
    age: int


json_schema = People.model_json_schema()

prompt = ("Generate a JSON with the name and age of one random person.")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_json": json_schema},
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)
```
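
Because `guided_json` constrains `content` to the schema, the result can be parsed straight back into the Pydantic model (a small follow-up sketch using the `People` class above):

```python
person = People.model_validate_json(completion.choices[0].message.content)
print(person.name, person.age)
```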
## Limitations
- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with [`tool_calling`](#tool_calling).
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.
## How to support a new reasoning model
@ -166,9 +208,10 @@ class DeepSeekReasoner(Reasoner):
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids
    ...
```
A structured output engine such as `xgrammar` uses `end_token_id` to check whether the model is still emitting reasoning content, and skips structured output enforcement until the reasoning has ended.
Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
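
Putting the pieces together, supporting a new model family amounts to a small `Reasoner` subclass plus a branch in `get_reasoner`. Below is a hedged sketch for a hypothetical model whose reasoning block ends with a `</reasoning>` token; the class name, token string, and constructor fields are invented for illustration, so mirror `DeepSeekReasoner` for the real structure:

```python
from dataclasses import dataclass

from transformers import PreTrainedTokenizer

from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner


@dataclass
class MyReasoner(Reasoner):  # hypothetical; fields are assumptions
    end_token_id: int

    @classmethod
    def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> "MyReasoner":
        # "</reasoning>" is a placeholder for the model's real end token.
        return cls(end_token_id=tokenizer.convert_tokens_to_ids("</reasoning>"))

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids

    def extract_content(self, input_ids: list[int]) -> list[int]:
        # Keep only the tokens generated after the reasoning-end token.
        if self.end_token_id not in input_ids:
            return []
        return input_ids[input_ids.index(self.end_token_id) + 1:]
```

You would then map a new parser name to this class inside `get_reasoner` so that `--reasoning-parser` can select it.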

View File

@ -33,6 +33,42 @@ client = OpenAI(
models = client.models.list()
model = models.data[0].id

# Guided decoding by Regex
prompt = ("What is the capital of France?")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={
        "guided_regex": "(Paris|London)",
    },
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)


class People(BaseModel):
    name: str
    age: int


json_schema = People.model_json_schema()

prompt = ("Generate a JSON with the name and age of one random person.")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_json": json_schema},
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)
# Guided decoding by JSON using Pydantic schema
class CarType(str, Enum):
@ -51,7 +87,7 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema()

prompt = ("Generate a JSON with the brand, model and car_type of "
          "the most iconic car from the 90's")
completion = client.chat.completions.create(
    model=model,
    messages=[{
@ -60,5 +96,34 @@ completion = client.chat.completions.create(
    }],
    extra_body={"guided_json": json_schema},
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)
# Guided decoding by Grammar
simplified_sql_grammar = """
?start: select_statement
?select_statement: "SELECT " column_list " FROM " table_name
?column_list: column_name ("," column_name)*
?table_name: identifier
?column_name: identifier
?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
"""

# This may be very slow https://github.com/vllm-project/vllm/issues/12122
prompt = ("Generate an SQL query to show the 'username' and 'email' "
          "from the 'users' table.")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_grammar": simplified_sql_grammar},
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)

View File

@ -112,6 +112,7 @@ async def get_guided_decoding_logits_processor(
reasoner = get_reasoner(tokenizer, reasoning_backend)
guided_params = maybe_backend_fallback(guided_params)
# CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
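
This path is taken when a request routes to the outlines backend. A hedged client-side sketch, reusing `client` and `model` from the examples above (the per-request `guided_decoding_backend` field is an assumption based on vLLM's guided decoding options):

```python
completion = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    extra_body={
        "guided_regex": "(Paris|London)",
        # Assumed option name: select the outlines engine for this request.
        "guided_decoding_backend": "outlines",
    },
)
print(completion.choices[0].message.reasoning_content)
print(completion.choices[0].message.content)
```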

View File

@ -43,7 +43,7 @@ class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
    self._guide: Guide = guide
    self._reasoner: Optional[Reasoner] = reasoner
    # CFGState is used for the FSM state for CFGGuide
    self._fsm_state: DefaultDict[int, Union[int,
                                            CFGState]] = defaultdict(int)
@ -54,10 +54,14 @@ class BaseLogitsProcessor:
# Skip the structured logits processing if reasoning is not finished.
# reasoner is not None only when `--enable-reasoning` is set.
if self._reasoner is not None:
    if not self._reasoner.is_reasoning_end(input_ids):
        return scores
    else:
        # Remove the reasoning tokens from the input_ids
        # We need this because our implementation relies on the
        # hash of the input_ids to store the FSM state.
        input_ids = self._reasoner.extract_content(input_ids)

seq_id = hash(tuple(input_ids))
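
To see why the reasoning tokens must be stripped before hashing, consider a toy comparison (made-up token ids; `100` stands in for the reasoning-end token):

```python
END = 100  # stands in for the </think> token id

# Two sequences that carry the same structured-output tokens but different
# reasoning prefixes should map to the same FSM-state key. Stripping the
# reasoning prefix makes the key a pure function of the content tokens.
seq_short = [7, 8, END, 42]
seq_long = [7, 8, 9, END, 42]

content_short = seq_short[seq_short.index(END) + 1:]  # [42]
content_long = seq_long[seq_long.index(END) + 1:]     # [42]
assert hash(tuple(content_short)) == hash(tuple(content_long))
```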

View File

@ -4,10 +4,13 @@ from __future__ import annotations
from transformers import PreTrainedTokenizer
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import (  # noqa: E501
    DeepSeekReasoner)
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
logger = init_logger(__name__)
def get_reasoner(tokenizer: PreTrainedTokenizer,
                 reasoning_backend: str | None) -> Reasoner | None:
@ -17,7 +20,12 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
elif reasoning_backend == "deepseek_r1":
    return DeepSeekReasoner.from_tokenizer(tokenizer)
else:
    # Log a warning for an unknown reasoning backend and return None.
    # We cannot raise an error here because some reasoning models
    # may not have a corresponding Reasoner class.
    logger.warning("Unknown reasoning backend %s for structured outputs",
                   reasoning_backend)
    return None
__all__ = ["Reasoner", "get_reasoner"]
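
Usage on the serving path is then a one-liner (a sketch; any tokenizer with the DeepSeek R1 vocabulary works, and the second call exercises the new warning branch):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1")

reasoner = get_reasoner(tokenizer, "deepseek_r1")  # -> DeepSeekReasoner
unknown = get_reasoner(tokenizer, "my_backend")    # logs a warning, -> None
```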

View File

@ -26,3 +26,13 @@ class DeepSeekReasoner(Reasoner):
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    return self.end_token_id in input_ids

def extract_content(self, input_ids: list[int]) -> list[int]:
    """
    Extract the content after the reasoning-end token.
    """
    if self.end_token_id not in input_ids or \
            input_ids.index(self.end_token_id) + 1 == len(input_ids):
        return []
    else:
        return input_ids[input_ids.index(self.end_token_id) + 1:]
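
A few concrete cases with toy token ids, `2` standing in for the reasoning-end token (the keyword-argument constructor is an assumption; adapt it to the real signature):

```python
r = DeepSeekReasoner(start_token_id=1, end_token_id=2)  # signature assumed

r.extract_content([5, 6, 2, 7, 8])  # -> [7, 8]: content after the end token
r.extract_content([5, 6, 2])        # -> []: end token is the last token
r.extract_content([5, 6])           # -> []: reasoning not finished yet
```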

View File

@ -17,3 +17,7 @@ class Reasoner(ABC):
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
    pass

@abstractmethod
def extract_content(self, input_ids: list[int]) -> list[int]:
    pass

View File

@ -392,7 +392,7 @@ class XGrammarLogitsProcessor:
def clone(self) -> XGrammarLogitsProcessor:
    """Create a new instance with shared compiled grammar
    but separate state"""
    new_processor = XGrammarLogitsProcessor(self.config, self.reasoner)

    # Share the compiled grammar context (immutable after compilation)
    new_processor.ctx = self.ctx
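
Propagating `self.reasoner` into the clone is the core of this fix: without it, a cloned processor would have no reasoner and would start enforcing the grammar from the first generated token, reasoning included.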