[Model] Add Reasoning Parser for Granite Models (#14202)

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Alex Brooks 2025-03-26 08:28:07 -06:00 committed by GitHub
parent c091c0a588
commit 1711b929b6
8 changed files with 730 additions and 3 deletions

View File

@@ -4,7 +4,7 @@
vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions.
Reasoning models return an additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models.

## Supported Models
@@ -14,6 +14,9 @@ vLLM currently supports the following reasoning models:
|--------------|-------------|------------------|-------------|
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ |
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |
| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
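As a quick illustration (a minimal sketch, not part of this change; the server address, API key, and served checkpoint below are placeholders), the option can be passed through `extra_body` on the OpenAI client:

```python
from openai import OpenAI

# Placeholder address/key; point these at your running vLLM server, started
# with a Granite 3.2 model and `--reasoning-parser granite`.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = client.models.list().data[0].id

response = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
    # Granite 3.2 reasoning is off by default; enable it per request.
    extra_body={"chat_template_kwargs": {"thinking": True}},
)
print(response.choices[0].message.reasoning_content)
print(response.choices[0].message.content)
```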
## Quickstart
@@ -43,6 +46,7 @@ model = models.data[0].id
# Round 1
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content
@@ -97,6 +101,7 @@ models = client.models.list()
model = models.data[0].id
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model,
                                        messages=messages,
                                        stream=True)
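Continuing from the `stream` object created above, a minimal sketch (not part of this change) of reading the reasoning and response deltas separately; `reasoning_content` is a vLLM extension that only appears on deltas when a reasoning parser is enabled, hence the defensive `getattr`:

```python
for chunk in stream:
    delta = chunk.choices[0].delta
    # reasoning_content may be absent, so fall back to None.
    reasoning_delta = getattr(delta, "reasoning_content", None)
    if reasoning_delta:
        print("reasoning:", reasoning_delta, flush=True)
    elif delta.content:
        print("content:", delta.content, flush=True)
```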

View File

@@ -31,6 +31,7 @@ model = models.data[0].id
# Round 1
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
response = client.chat.completions.create(model=model, messages=messages)
reasoning_content = response.choices[0].message.reasoning_content

View File

@@ -38,6 +38,7 @@ models = client.models.list()
model = models.data[0].id
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model,
                                        messages=messages,
                                        stream=True)

View File

@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from transformers import AutoTokenizer
from tests.entrypoints.openai.reasoning_parsers.utils import (
DeltaMessage, run_reasoning_extraction)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
ReasoningParserManager)
parser_name = "granite"
START_REASONING = "Here is my thought process:"
START_RESPONSE = "Here is my response:"
SIMPLE_REASONING = {
"output":
f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
}
COMPLETE_REASONING = {
"output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
"reasoning_content": "This is a reasoning section",
"content": None,
}
NO_REASONING = {
"output": "This is content",
"reasoning_content": None,
"content": "This is content",
}
MULTIPLE_LINES = {
"output":
f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
}
REASONING_WITH_THINK = {
"output":
f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest", #noqa: E501
"reasoning_content": "This is a reasoning section",
"content": "This is the rest",
}
COMPLETE_REASONING_WITH_THINK = {
"output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
"reasoning_content": "This is a reasoning section",
"content": None,
}
MULTIPLE_LINES_WITH_THINK = {
"output":
f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
"reasoning_content": "This\nThat",
"content": "This is the rest\nThat",
}
TEST_CASES = [
pytest.param(
False,
SIMPLE_REASONING,
id="simple_reasoning",
),
pytest.param(
False,
COMPLETE_REASONING,
id="complete_reasoning",
),
pytest.param(
False,
NO_REASONING,
id="no_reasoning",
),
pytest.param(
False,
MULTIPLE_LINES,
id="multiple_lines",
),
pytest.param(
False,
REASONING_WITH_THINK,
id="reasoning_with_think",
),
pytest.param(
False,
COMPLETE_REASONING_WITH_THINK,
id="complete_reasoning_with_think",
),
pytest.param(
False,
MULTIPLE_LINES_WITH_THINK,
id="multiple_lines_with_think",
),
pytest.param(
True,
SIMPLE_REASONING,
id="simple_reasoning_streaming",
),
pytest.param(
True,
COMPLETE_REASONING,
id="complete_reasoning_streaming",
),
pytest.param(
True,
NO_REASONING,
id="no_reasoning_streaming",
),
pytest.param(
True,
MULTIPLE_LINES,
id="multiple_lines_streaming",
),
pytest.param(
True,
REASONING_WITH_THINK,
id="reasoning_with_think_streaming",
),
pytest.param(
True,
COMPLETE_REASONING_WITH_THINK,
id="complete_reasoning_with_think_streaming",
),
pytest.param(
True,
MULTIPLE_LINES_WITH_THINK,
id="multiple_lines_with_think_streaming",
),
]
# Global tokenizer initialization to avoid repeated loading
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
streaming: bool,
param_dict: dict,
):
output = tokenizer.tokenize(param_dict["output"])
# Convert each token back to its string form so the output can be fed in piecewise
output_tokens: list[str] = [
tokenizer.convert_tokens_to_string([token]) for token in output
]
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(tokenizer)
reasoning, content = run_reasoning_extraction(parser,
output_tokens,
streaming=streaming)
assert reasoning == param_dict["reasoning_content"]
assert content == param_dict["content"]
# Additional tests for verifying the correctness of granite streaming; this
# is complicated because granite uses multiple tokens to indicate when thinking
# is starting / when it's starting its response, so skipping special tokens
# is awkward.
### Handling the start of reasoning
STREAMING_1 = {
"previous_text": None,
"current_text": "Here",
"delta_text": "Here",
"reasoning_content": None,
"content": None,
}
# When we fail, we should give what was previously being silenced first
STREAMING_2 = {
"previous_text": "Here is my thought",
"current_text": "Here is my thought failure",
"delta_text": " failure",
"reasoning_content": None,
"content": "Here is my thought failure",
}
# But then after the first one, we should only add the delta text to content
STREAMING_3 = {
"previous_text": "Here wrong",
"current_text": " words",
"delta_text": " Here wrong words",
"reasoning_content": None,
"content": " words",
}
# The reasoning start sequence completes here; nothing should be emitted yet
STREAMING_4 = {
"previous_text": "Here is my thought",
"current_text": "Here is my thought process:",
"delta_text": " process:",
"reasoning_content": None,
"content": None,
}
# Reasoning started successfully; parse reasoning content
STREAMING_5 = {
"previous_text": "Here is my thought process:",
"current_text": "Here is my thought process: foo",
"delta_text": " foo",
"reasoning_content": " foo",
"content": None,
}
# Response special sequence has started, but not finished.
STREAMING_6 = {
"previous_text": "Here is my thought process: foo",
"current_text": "Here is my thought process: foo Here is",
"delta_text": " Here is",
"reasoning_content": " ",
"content": None,
}
# Response special sequence started, but was broken; the reasoning
# content should be the content that was previously unused.
STREAMING_7 = {
"previous_text": "Here is my thought process: foo Here is",
"current_text": "Here is my thought process: foo Here is Here",
"delta_text": " Here",
"reasoning_content": "Here is ",
"content": None,
}
# Response special sequence is ongoing
STREAMING_8 = {
"previous_text": "Here is my thought process: foo Here is my response:",
"current_text": "Here is my thought process: foo Here is my response: bar",
"delta_text": " bar",
"reasoning_content": None,
"content": " bar",
}
# The delta text has everything; we should be able to correctly parse both
STREAMING_9 = {
"previous_text": None,
"current_text": "Here is my thought process: foo Here is my response: bar",
"delta_text": "Here is my thought process: foo Here is my response: bar",
"reasoning_content": " foo ",
"content": " bar",
}
## The Response is ongoing, and the delta mixes reasoning content / content
STREAMING_10 = {
"previous_text": "Here is my thought process: foo",
"current_text":
"Here is my thought process: foo bar Here is my response: baz",
"delta_text": " bar Here is my response: baz",
"reasoning_content": " bar ",
"content": " baz",
}
# The delta text starts a new substring that might be a response special seq
STREAMING_11 = {
"previous_text":
"Here is my thought process: This is a reasoning section ",
"current_text":
"Here is my thought process: This is a reasoning section Here",
"delta_text": "Here",
"reasoning_content": None,
"content": None,
}
# The delta text is finishing the response special seq
STREAMING_12 = {
"previous_text": "Here is my thought process: foo Here is my response",
"current_text": "Here is my thought process: foo Here is my response:",
"delta_text": ":",
"reasoning_content": None,
"content": None,
}
STREAMING_13 = {
"previous_text": "Here is my thought process: foo Here",
"current_text": "Here is my thought process: foo Here was",
"delta_text": " was",
"reasoning_content": "Here was",
"content": None,
}
STREAMING_SUBCASES = [
pytest.param(
STREAMING_1,
id="Starting reasoning special sequence",
),
pytest.param(
STREAMING_2,
id="Unexpected start reasoning sequence",
),
pytest.param(
STREAMING_3,
id="Continuing unexpected start reasoning sequence",
),
pytest.param(
STREAMING_4,
id="Only start reasoning sequence and nothing else",
),
pytest.param(
STREAMING_5,
id="Reasoning content has started",
),
pytest.param(
STREAMING_6,
id="Response special sequence has started",
),
pytest.param(
STREAMING_7,
id="Response special sequence reset",
),
pytest.param(
STREAMING_8,
id="Response text has started",
),
pytest.param(
STREAMING_9,
id="Delta contains everything",
),
pytest.param(
STREAMING_10,
id="Delta contains some reasoning and response",
),
pytest.param(
STREAMING_11,
id="Delta starts response sequence",
),
pytest.param(
STREAMING_12,
id="Delta finishes response sequence",
),
pytest.param(
STREAMING_13,
id="Delta breaks potential responise sequence",
),
]
@pytest.mark.parametrize("param_dict", STREAMING_SUBCASES)
def test_streaming_subcases(param_dict):
# Get all of the token IDs
previous_token_ids = tokenizer.encode(
param_dict["previous_text"]
) if param_dict["previous_text"] is not None else []
current_token_ids = tokenizer.encode(param_dict["current_text"])
delta_token_ids = tokenizer.encode(param_dict["delta_text"])
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
parser_name)(tokenizer)
response = parser.extract_reasoning_content_streaming(
previous_text=param_dict["previous_text"],
current_text=param_dict["current_text"],
delta_text=param_dict["delta_text"],
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
# Streaming currently expects at least one of reasoning content / content;
# if both are None, the parser should return None.
if param_dict["reasoning_content"] is None and param_dict[
"content"] is None:
assert response is None
else:
assert isinstance(response, DeltaMessage)
assert param_dict["reasoning_content"] == response.reasoning_content
assert param_dict["content"] == response.content

View File

@@ -1099,7 +1099,7 @@ class EngineArgs:
parser.add_argument(
    "--reasoning-parser",
    type=str,
    choices=["deepseek_r1", "granite"],
    default=None,
    help=
    "Select the reasoning parser depending on the model that you're "

View File

@@ -2,7 +2,11 @@
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .granite_reasoning_parser import GraniteReasoningParser

__all__ = [
    "ReasoningParser",
    "ReasoningParserManager",
    "DeepSeekR1ReasoningParser",
    "GraniteReasoningParser",
]

View File

@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
import re
from collections.abc import Sequence
from typing import Optional, Union
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger
logger = init_logger(__name__)
@ReasoningParserManager.register_module("granite")
class GraniteReasoningParser(ReasoningParser):
"""
Reasoning parser for IBM Granite.
IBM Granite models currently use "Here is my thought process:"
and "Here is my response:" to separate their thinking and response outputs.
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase):
super().__init__(tokenizer)
# NOTE: There have been some observed occurrences of quantized
# instances of the current models using "Here's" instead of "Here is",
# so to be safe, we match on both.
self.think_start_expr = r"(?:Here's|Here is) my thought process:"
self.response_start_expr = r"(?:Here's|Here is) my response:"
self.reasoning_regex = re.compile(
rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
re.DOTALL)
self.valid_think_starts = [
"Here's my thought process:", "Here is my thought process:"
]
self.valid_response_starts = [
"Here's my response:", "Here is my response:"
]
# Substrings to match for sequence boundaries on raw text
self.seq_boundary_end = ":"
self.seq_boundary_start = "Here"
# The longest any thinking / start of response message can be
self.longest_think_start = max(
len(think_start) for think_start in self.valid_think_starts)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> tuple[Optional[str], Optional[str]]:
"""Extract the reasoning content & content sections, respectively.
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
Args:
model_output (str): Output of the model to be parsed.
request (ChatCompletionRequest): Request being processed.
Returns:
tuple[Optional[str], Optional[str]]: Tuple pair containing the
reasoning content and non-reasoning content.
"""
re_match = self.reasoning_regex.findall(model_output)
if not re_match:
return None, model_output
reasoning_content, response_content = re_match[0]
if not response_content:
return reasoning_content, None
return reasoning_content, response_content
def extract_reasoning_content_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> Union[DeltaMessage, None]:
"""Extract the reasoning content / content emitted by granite models;
If the sequence doesn't match what we expect, i.e., the model generates
something else, all content is considered non-reasoning content.
NOTE: Granite models do not use a special token to start their reasoning
and response sections; instead they have token sequences, e.g.,
Here is my thought process: Foo Here is my response: Bar
This increases the complexity of correctly handling streams, since we
need to watch for specific sequences and correctly parse them without
dropping content that is potentially overlapping & spanning multiple
delta messages.
Args:
previous_text (str): Previous text outside of this delta message.
current_text (str): Previous text + delta text.
delta_text (str): Text to consider and parse content from.
previous_token_ids (Sequence[int]): Token IDs of previous_text.
current_token_ids (Sequence[int]): Token IDs of current_text.
delta_token_ids (Sequence[int]): Token IDs of delta_text.
Returns:
Union[DeltaMessage, None]
DeltaMessage with either reasoning content or content, or None.
"""
reasoning_content, resp_seq_len, content = self._get_content_sections(
current_text)
# Either we haven't finished the start of the reasoning sequence,
# or the model is generating something unexpected.
if not reasoning_content:
delta_message = self._get_delta_message_with_no_reasoning_bounds(
current_text, delta_text)
# We have a start of reasoning message, but have not yet finished
# the start of response sequence.
elif not content:
delta_message = self._get_delta_message_with_no_response_bounds(
current_text, reasoning_content, delta_text)
# We've finished both the start of reasoning and start of response seq.
else:
# This should never happen since we matched on the response
assert resp_seq_len is not None
delta_message = self._get_delta_message_with_both_bounds(
delta_text, reasoning_content, content, current_text,
resp_seq_len)
if not delta_message.content and not delta_message.reasoning_content:
return None
return delta_message
#### Implementation details of stream parsing for granite models
def _is_reasoning_start_substr(self, text: str) -> bool:
"""Check if a text matches one of the possible start reasoning seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible reasoning start seqs match.
"""
return any(
think_start.startswith(text)
for think_start in self.valid_think_starts)
def _is_response_start_substr(self, text: str) -> bool:
"""Check if a text matches one of the possible start response seqs.
Args:
text (str): Text to check for leading substr.
Returns:
bool: True if any of the possible response start seqs match.
"""
return any(
response_start.startswith(text)
for response_start in self.valid_response_starts)
def _get_delta_message_with_no_reasoning_bounds(
self,
current_text: str,
delta_text: str,
) -> DeltaMessage:
"""Parse the delta message when the current text has not yet completed
its start of reasoning sequence.
Args:
current_text (str): The full previous + delta text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
prev_longest_length = len(current_text) - len(delta_text)
is_substr = self._is_reasoning_start_substr(current_text)
was_substr = self._is_reasoning_start_substr(
current_text[:prev_longest_length])
# Check if we just generated something NOT in the special token seq;
# if so, add everything that we previously skipped with this delta
# message and append everything to content in the future.
if was_substr and not is_substr:
return DeltaMessage(
reasoning_content=None,
content=current_text,
)
if is_substr:
# Might still be in the special token sequence; return nothing
return DeltaMessage(reasoning_content=None, content=None)
# Otherwise the sequence has already been broken and we already
# corrected; just return the delta text as normal content.
return DeltaMessage(reasoning_content=None, content=delta_text)
def _get_delta_message_with_no_response_bounds(
self,
current_text: str,
reasoning_content: str,
delta_text: str,
) -> DeltaMessage:
"""Parse the delta message when the current text has both reasoning
content with no (response) content. NOTE that we may have overlapping
tokens with the start of reasoning / start of response sequences on
either side of the delta text.
Args:
current_text (str): The full previous + delta text.
reasoning_content (str): reasoning content from current_text.
delta_text (str): Text to consider and parse content from.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# If we have no reasoning content or explicitly end with the start of
# response sequence, we are in transition to the response; need to be
# careful here, since the final token (:) will match the reasoning
# content and fully parse it out; we should not pass the : back.
ends_with_start_response_seq = any(
current_text.endswith(response_start)
for response_start in self.valid_response_starts)
if reasoning_content is None or ends_with_start_response_seq:
return DeltaMessage(reasoning_content=None, content=None)
# Consider previous / current text only within context of the reasoning
previous_text = reasoning_content[:-len(delta_text)]
current_text = reasoning_content
# We need to be careful about adding unfinished response sequences;
# Find the place at which we MIGHT be starting a response sequence
prev_idx = previous_text.rfind(self.seq_boundary_start)
delta_idx = delta_text.rfind(self.seq_boundary_start)
# Check the state of potential start of response substring matches.
prev_was_substr = self._is_response_start_substr(
previous_text[prev_idx:]) if prev_idx >= 0 else False
delta_continues_substr = self._is_response_start_substr(
current_text[prev_idx:]) if prev_idx >= 0 else False
delta_new_substr = self._is_response_start_substr(
delta_text[delta_idx:]) if delta_idx >= 0 else False
# Delta only contains potential continued response sequence text.
if delta_continues_substr:
return DeltaMessage(reasoning_content=None, content=None)
if not prev_was_substr:
# Delta may be starting a new response seq but has other text too.
if delta_new_substr:
return DeltaMessage(reasoning_content=delta_text[:delta_idx],
content=None)
# Normal case for most reasoning text (no potential special seqs).
return DeltaMessage(reasoning_content=delta_text, content=None)
# The substring that previously seemed to be a potential response
# seq wasn't one; we need to add the content to the delta message,
# and also slice off the potential response sequence
elif delta_new_substr:
reasoning_content = previous_text[
prev_idx:] + delta_text[:delta_idx]
return DeltaMessage(reasoning_content=reasoning_content,
content=None)
# No new substring yet, and we broke our old one; take the whole delta
return DeltaMessage(
reasoning_content=previous_text[prev_idx:] + delta_text,
content=None,
)
def _get_delta_message_with_both_bounds(
self,
delta_text: str,
reasoning_content: str,
response_content: str,
current_text: str,
response_seq_len: int,
) -> DeltaMessage:
"""Parse the delta message when the current text has both reasoning
content and normal (response) content.
Args:
delta_text (str): Text to consider and parse content from.
reasoning_content (str): reasoning content from current_text.
response_content (str): response content from current_text.
current_text (str): The full previous + delta text.
response_seq_len (int): Length of the complete response sequence used.
Returns:
DeltaMessage: Message containing the parsed content.
"""
# Always have content; take length to the end
delta_content = delta_text[-len(response_content):]
reasoning_end_idx = len(delta_text) - (len(response_content) +
response_seq_len)
if reasoning_end_idx < 0:
delta_reasoning_content = None
else:
# Get the starting offset
start_reasoning_content_idx = len(
reasoning_content) + response_seq_len + len(
response_content) - 1
delta_offset = len(current_text) - len(delta_text)
start_offset = start_reasoning_content_idx - delta_offset
if start_offset < 0:
start_offset = 0
delta_reasoning_content = delta_text[
start_offset:reasoning_end_idx]
return DeltaMessage(
reasoning_content=delta_reasoning_content,
content=delta_content,
)
def _get_content_sections(
self, current_text: str
) -> tuple[Optional[str], Optional[int], Optional[str]]:
"""Parse the text to extract the reasoning content / content
if we have them.
Args:
current_text (str): The full previous + delta text.
Returns:
tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
containing the reasoning content, the length of the response seq
(if there is one) and the non-reasoning content.
"""
current_chunk_start = 0
start_reasoning_content = None
parsed_content = False
delimiter_idxs = [
idx for idx, char in enumerate(current_text)
if char == self.seq_boundary_end
]
for current_chunk_end in delimiter_idxs:
current_chunk = current_text[current_chunk_start:current_chunk_end]
# Check to see if the start of reasoning seq is complete
if start_reasoning_content is None:
for think_start in self.valid_think_starts:
if current_chunk == think_start[:-1]:
start_reasoning_content = current_chunk_end + 1
current_chunk_start = current_chunk_end + 1
break
# Check to see if the start of response seq is complete
elif not parsed_content:
for response_start in self.valid_response_starts:
if current_chunk[-len(response_start) +
1:] == response_start[:-1]:
# Mark end of reasoning and start response content
# after the start of response sequence.
end_reasoning_content = current_chunk_end - len(
response_start)
reasoning_content = current_text[
start_reasoning_content:end_reasoning_content]
response_content = current_text[current_chunk_end + 1:]
return reasoning_content, len(
response_start), response_content
if start_reasoning_content and not parsed_content:
return current_text[start_reasoning_content:], None, None
return None, None, None
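As an aside (an editor's sketch, not part of the commit): the non-streaming path reduces to the single `reasoning_regex` built in `__init__`. The sample string below is invented, but it shows how a typical Granite output is split into the two sections:

```python
import re

# Same pattern the parser compiles; DOTALL lets reasoning span multiple lines.
think_start = r"(?:Here's|Here is) my thought process:"
response_start = r"(?:Here's|Here is) my response:"
reasoning_regex = re.compile(rf"{think_start}(.*?){response_start}(.*)", re.DOTALL)

sample = ("Here is my thought process: 9.8 has a larger fractional part than "
          "9.11. Here is my response: 9.8 is greater.")
reasoning_content, response_content = reasoning_regex.findall(sample)[0]
print(repr(reasoning_content))  # ' 9.8 has a larger fractional part than 9.11. '
print(repr(response_content))   # ' 9.8 is greater.'
```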

View File

@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
return None
elif reasoning_backend == "deepseek_r1":
return DeepSeekReasoner.from_tokenizer(tokenizer)
elif reasoning_backend == "granite":
logger.warning(
"Granite reasoner not yet implemented for structured outputs")
return None
else:
# Raise a warning for unknown reasoning backend and return None
# We cannot raise an error here because some reasoning models