[Bugfix][Structured Output] Support outlines engine with reasoning outputs for DeepSeek R1 (#14114)
This commit is contained in:
parent abcc61e0af
commit f5f7f00cd9
@ -10,7 +10,9 @@ Reasoning models return an additional `reasoning_content` field in their outputs,
vLLM currently supports the following reasoning models:

- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) (`deepseek_r1`, which looks for `<think> ... </think>`)

| Model Series | Parser Name | Structured Output Support |
|--------------|-------------|------------------|
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` |

## Quickstart

@ -78,11 +80,51 @@ Streaming chat completions are also supported for reasoning models. The `reasoni
Please note that this is not compatible with the OpenAI Python client library; you can use the `requests` library to make streaming requests instead. See the [example](https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py).
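For illustration, a minimal streaming sketch with `requests` might look like the following; the model name, port, and prompt are placeholders, and the linked example is the complete version:

```python
import json

import requests

# Placeholder model name and URL; adjust to match your running vLLM server.
response = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "messages": [{"role": "user", "content": "Briefly explain quicksort."}],
        "stream": True,
    },
    stream=True,
)

for line in response.iter_lines():
    # The server streams Server-Sent Events lines of the form `data: {...}`.
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    delta = json.loads(payload)["choices"][0]["delta"]
    # Reasoning tokens arrive in `reasoning_content`; the final answer in `content`.
    print(delta.get("reasoning_content") or delta.get("content") or "", end="")
```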
## Structured output
The reasoning content is also available with structured output. A structured output engine such as `xgrammar` will use the reasoning content to generate the structured output.
```python
from openai import OpenAI
from pydantic import BaseModel

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id


class People(BaseModel):
    name: str
    age: int


json_schema = People.model_json_schema()

prompt = ("Generate a JSON with the name and age of one random person.")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_json": json_schema},
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)
```
## Limitations

- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`).
- It is not compatible with [`tool_calling`](#tool_calling).
- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning.

## How to support a new reasoning model

@ -166,9 +208,10 @@ class DeepSeekReasoner(Reasoner):
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids
    ...
```

The structured output engine like xgrammar will use `end_token_id` to check whether the reasoning content is still present in the model output, and skip the structured output while that is the case.
The structured output engine like `xgrammar` will use `end_token_id` to check whether the reasoning content is still present in the model output, and skip the structured output while that is the case.
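
In rough terms, the check works like the sketch below; this is a simplified illustration of the logits-processor changes shown later in this diff, with `reasoner` and `scores` standing in for the engine's internal state:

```python
# Simplified sketch, not the actual vLLM API: the structured output engine
# consults the reasoner before constraining the logits.
def apply_structured_output(reasoner, input_ids: list[int], scores):
    # While reasoning has not ended (no end token such as </think> yet),
    # return the logits untouched so the reasoning stays unconstrained.
    if reasoner is not None and not reasoner.is_reasoning_end(input_ids):
        return scores
    # Otherwise, apply the grammar/JSON-schema constraints to the answer tokens.
    ...
```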
Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags.
@ -33,6 +33,42 @@ client = OpenAI(
models = client.models.list()
model = models.data[0].id

# Guided decoding by Regex
prompt = ("What is the capital of France?")

completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={
        "guided_regex": "(Paris|London)",
    },
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)


class People(BaseModel):
    name: str
    age: int


json_schema = People.model_json_schema()

prompt = ("Generate a JSON with the name and age of one random person.")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_json": json_schema},
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)


# Guided decoding by JSON using Pydantic schema
class CarType(str, Enum):
@ -51,7 +87,7 @@ class CarDescription(BaseModel):
json_schema = CarDescription.model_json_schema()

prompt = ("Generate a JSON with the brand, model and car_type of"
          "the most iconic car from the 90's, think in 100 tokens")
          "the most iconic car from the 90's")
completion = client.chat.completions.create(
    model=model,
    messages=[{
@ -60,5 +96,34 @@ completion = client.chat.completions.create(
    }],
    extra_body={"guided_json": json_schema},
)
print("content", completion.choices[0].message.content)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)


# Guided decoding by Grammar
simplified_sql_grammar = """
?start: select_statement

?select_statement: "SELECT " column_list " FROM " table_name

?column_list: column_name ("," column_name)*

?table_name: identifier

?column_name: identifier

?identifier: /[a-zA-Z_][a-zA-Z0-9_]*/
"""

# This may be very slow https://github.com/vllm-project/vllm/issues/12122
prompt = ("Generate an SQL query to show the 'username' and 'email'"
          "from the 'users' table.")
completion = client.chat.completions.create(
    model=model,
    messages=[{
        "role": "user",
        "content": prompt,
    }],
    extra_body={"guided_grammar": simplified_sql_grammar},
)
print("reasoning_content: ", completion.choices[0].message.reasoning_content)
print("content: ", completion.choices[0].message.content)
@ -112,6 +112,7 @@ async def get_guided_decoding_logits_processor(
    reasoner = get_reasoner(tokenizer, reasoning_backend)

    guided_params = maybe_backend_fallback(guided_params)

    # CFG grammar not supported by LMFE, so we use outlines instead
    if guided_params.backend_name == 'outlines':
        # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
@ -43,7 +43,7 @@ class BaseLogitsProcessor:
    def __init__(self, guide: Guide, reasoner: Optional[Reasoner]):
        self._guide: Guide = guide
        self._reasoner = reasoner
        self._reasoner: Optional[Reasoner] = reasoner
        # CFGState is used for the FSM state for CFGGuide
        self._fsm_state: DefaultDict[int, Union[int,
                                                CFGState]] = defaultdict(int)
@ -54,10 +54,14 @@ class BaseLogitsProcessor:
        # Skip the structured logits processing if reasoning is not finished.
        # reasoner is not None only when `--enable-reasoning` is set.
        if self._reasoner is not None and \
            not self._reasoner.is_reasoning_end(
                input_ids):
            return scores
        if self._reasoner is not None:
            if not self._reasoner.is_reasoning_end(input_ids):
                return scores
            else:
                # Remove the reasoning tokens from the input_ids
                # We need this because our implementation relies on the
                # hash of the input_ids to store the FSM state.
                input_ids = self._reasoner.extract_content(input_ids)

        seq_id = hash(tuple(input_ids))
@ -4,10 +4,13 @@ from __future__ import annotations
from transformers import PreTrainedTokenizer

from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner.deepseek_reasoner import (  # noqa: E501
    DeepSeekReasoner)
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner

logger = init_logger(__name__)


def get_reasoner(tokenizer: PreTrainedTokenizer,
                 reasoning_backend: str | None) -> Reasoner | None:
@ -17,7 +20,12 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
    elif reasoning_backend == "deepseek_r1":
        return DeepSeekReasoner.from_tokenizer(tokenizer)
    else:
        raise ValueError(f"Unknown reasoning backend '{reasoning_backend}'")
        # Raise a warning for unknown reasoning backend and return None
        # We cannot raise an error here because some reasoning models
        # may not have a corresponding Reasoner class.
        logger.warning("Unknown reasoning backend %s for structured outputs ",
                       reasoning_backend)
        return None


__all__ = ["Reasoner", "get_reasoner"]
@ -26,3 +26,13 @@ class DeepSeekReasoner(Reasoner):
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        return self.end_token_id in input_ids

    def extract_content(self, input_ids: list[int]) -> list[int]:
        """
        Extract the content after the end tokens
        """
        if self.end_token_id not in input_ids or \
                input_ids.index(self.end_token_id) + 1 == len(input_ids):
            return []
        else:
            return input_ids[input_ids.index(self.end_token_id) + 1:]
@ -17,3 +17,7 @@ class Reasoner(ABC):
    @abstractmethod
    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        pass

    @abstractmethod
    def extract_content(self, input_ids: list[int]) -> list[int]:
        pass
@ -392,7 +392,7 @@ class XGrammarLogitsProcessor:
    def clone(self) -> XGrammarLogitsProcessor:
        """Create a new instance with shared compiled grammar
        but separate state"""
        new_processor = XGrammarLogitsProcessor(self.config)
        new_processor = XGrammarLogitsProcessor(self.config, self.reasoner)

        # Share the compiled grammar context (immutable after compilation)
        new_processor.ctx = self.ctx