[Model] Add Reasoning Parser for Granite Models (#14202)
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
This commit is contained in: parent c091c0a588, commit 1711b929b6
@@ -4,7 +4,7 @@
vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions.

-Reasoning models return a additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models.
+Reasoning models return an additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models.

## Supported Models

@@ -14,6 +14,9 @@ vLLM currently supports the following reasoning models:
|--------------|-------------|------------------|-------------|
| [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ |
| [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ |
+| [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ |
+
+- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`.
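A minimal sketch of enabling Granite reasoning through the OpenAI-compatible client (the server URL, API key, and model name below are placeholders, not part of this diff):

```python
from openai import OpenAI

# Assumes a vLLM server started with a Granite 3.2 model and
# `--enable-reasoning --reasoning-parser granite`.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="ibm-granite/granite-3.2-8b-instruct",  # placeholder model name
    messages=[{"role": "user", "content": "9.11 and 9.8, which is greater?"}],
    # Granite reasoning is off by default; opt in via chat_template_kwargs.
    extra_body={"chat_template_kwargs": {"thinking": True}},
)
print(response.choices[0].message.reasoning_content)
print(response.choices[0].message.content)
```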
## Quickstart
@@ -43,6 +46,7 @@ model = models.data[0].id

# Round 1
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
response = client.chat.completions.create(model=model, messages=messages)

reasoning_content = response.choices[0].message.reasoning_content

@@ -97,6 +101,7 @@ models = client.models.list()
model = models.data[0].id

messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model,
                                        messages=messages,
                                        stream=True)
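The diff elides the remainder of the streaming example; a minimal sketch (not part of this diff) of consuming such a stream, assuming each delta exposes the `reasoning_content` field described above:

```python
for chunk in stream:
    delta = chunk.choices[0].delta
    # Reasoning tokens and response tokens arrive in separate fields.
    if getattr(delta, "reasoning_content", None):
        print(delta.reasoning_content, end="", flush=True)
    elif getattr(delta, "content", None):
        print(delta.content, end="", flush=True)
```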
@@ -31,6 +31,7 @@ model = models.data[0].id

# Round 1
messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
response = client.chat.completions.create(model=model, messages=messages)

reasoning_content = response.choices[0].message.reasoning_content

@@ -38,6 +38,7 @@ models = client.models.list()
model = models.data[0].id

messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+# For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}`
stream = client.chat.completions.create(model=model,
                                        messages=messages,
                                        stream=True)
@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from transformers import AutoTokenizer

from tests.entrypoints.openai.reasoning_parsers.utils import (
    DeltaMessage, run_reasoning_extraction)
from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser,
                                                       ReasoningParserManager)

parser_name = "granite"
START_REASONING = "Here is my thought process:"
START_RESPONSE = "Here is my response:"

SIMPLE_REASONING = {
    "output":
    f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest",  # noqa: E501
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
COMPLETE_REASONING = {
    "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}
NO_REASONING = {
    "output": "This is content",
    "reasoning_content": None,
    "content": "This is content",
}
MULTIPLE_LINES = {
    "output":
    f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
}
REASONING_WITH_THINK = {
    "output":
    f"{START_REASONING}This is a reasoning section{START_RESPONSE}This is the rest",  # noqa: E501
    "reasoning_content": "This is a reasoning section",
    "content": "This is the rest",
}
COMPLETE_REASONING_WITH_THINK = {
    "output": f"{START_REASONING}This is a reasoning section{START_RESPONSE}",
    "reasoning_content": "This is a reasoning section",
    "content": None,
}
MULTIPLE_LINES_WITH_THINK = {
    "output":
    f"{START_REASONING}This\nThat{START_RESPONSE}This is the rest\nThat",
    "reasoning_content": "This\nThat",
    "content": "This is the rest\nThat",
}
TEST_CASES = [
    pytest.param(
        False,
        SIMPLE_REASONING,
        id="simple_reasoning",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING,
        id="complete_reasoning",
    ),
    pytest.param(
        False,
        NO_REASONING,
        id="no_reasoning",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES,
        id="multiple_lines",
    ),
    pytest.param(
        False,
        REASONING_WITH_THINK,
        id="reasoning_with_think",
    ),
    pytest.param(
        False,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think",
    ),
    pytest.param(
        False,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think",
    ),
    pytest.param(
        True,
        SIMPLE_REASONING,
        id="simple_reasoning_streaming",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING,
        id="complete_reasoning_streaming",
    ),
    pytest.param(
        True,
        NO_REASONING,
        id="no_reasoning_streaming",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES,
        id="multiple_lines_streaming",
    ),
    pytest.param(
        True,
        REASONING_WITH_THINK,
        id="reasoning_with_think_streaming",
    ),
    pytest.param(
        True,
        COMPLETE_REASONING_WITH_THINK,
        id="complete_reasoning_with_think_streaming",
    ),
    pytest.param(
        True,
        MULTIPLE_LINES_WITH_THINK,
        id="multiple_lines_with_think_streaming",
    ),
]

# Global tokenizer initialization to avoid repeated loading
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
def test_reasoning(
    streaming: bool,
    param_dict: dict,
):
    output = tokenizer.tokenize(param_dict["output"])
    # decode everything to tokens
    output_tokens: list[str] = [
        tokenizer.convert_tokens_to_string([token]) for token in output
    ]
    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(tokenizer)

    reasoning, content = run_reasoning_extraction(parser,
                                                  output_tokens,
                                                  streaming=streaming)

    assert reasoning == param_dict["reasoning_content"]
    assert content == param_dict["content"]


# Additional tests for verifying the correctness of granite streaming; this
# is complicated because granite uses multiple tokens to indicate when thinking
# is starting / when it's starting its response, so skipping special tokens
# is awkward.
### Handling the start of reasoning
STREAMING_1 = {
    "previous_text": None,
    "current_text": "Here",
    "delta_text": "Here",
    "reasoning_content": None,
    "content": None,
}
# When we fail, we should give what was previously being silenced first
STREAMING_2 = {
    "previous_text": "Here is my thought",
    "current_text": "Here is my thought failure",
    "delta_text": " failure",
    "reasoning_content": None,
    "content": "Here is my thought failure",
}
# But then after the first one, we should only add the delta text to content
STREAMING_3 = {
    "previous_text": "Here wrong",
    "current_text": "Here wrong words",
    "delta_text": " words",
    "reasoning_content": None,
    "content": " words",
}
# Completing the start of reasoning sequence should emit nothing yet
STREAMING_4 = {
    "previous_text": "Here is my thought",
    "current_text": "Here is my thought process:",
    "delta_text": " process:",
    "reasoning_content": None,
    "content": None,
}
# Reasoning started successfully; parse reasoning content
STREAMING_5 = {
    "previous_text": "Here is my thought process:",
    "current_text": "Here is my thought process: foo",
    "delta_text": " foo",
    "reasoning_content": " foo",
    "content": None,
}
# Response special sequence has started, but not finished.
STREAMING_6 = {
    "previous_text": "Here is my thought process: foo",
    "current_text": "Here is my thought process: foo Here is",
    "delta_text": " Here is",
    "reasoning_content": " ",
    "content": None,
}
# Response special sequence started, but was broken; the reasoning
# content should be the content that was previously unused.
STREAMING_7 = {
    "previous_text": "Here is my thought process: foo Here is",
    "current_text": "Here is my thought process: foo Here is Here",
    "delta_text": " Here",
    "reasoning_content": "Here is ",
    "content": None,
}
# Response special sequence is ongoing
STREAMING_8 = {
    "previous_text": "Here is my thought process: foo Here is my response:",
    "current_text": "Here is my thought process: foo Here is my response: bar",
    "delta_text": " bar",
    "reasoning_content": None,
    "content": " bar",
}
# The delta text has everything; we should be able to correctly parse both
STREAMING_9 = {
    "previous_text": None,
    "current_text": "Here is my thought process: foo Here is my response: bar",
    "delta_text": "Here is my thought process: foo Here is my response: bar",
    "reasoning_content": " foo ",
    "content": " bar",
}
## The Response is ongoing, and the delta mixes reasoning content / content
STREAMING_10 = {
    "previous_text": "Here is my thought process: foo",
    "current_text":
    "Here is my thought process: foo bar Here is my response: baz",
    "delta_text": " bar Here is my response: baz",
    "reasoning_content": " bar ",
    "content": " baz",
}
# The delta text starts a new substring that might be a response special seq
STREAMING_11 = {
    "previous_text":
    "Here is my thought process: This is a reasoning section ",
    "current_text":
    "Here is my thought process: This is a reasoning section Here",
    "delta_text": "Here",
    "reasoning_content": None,
    "content": None,
}
# The delta text is finishing the response special seq
STREAMING_12 = {
    "previous_text": "Here is my thought process: foo Here is my response",
    "current_text": "Here is my thought process: foo Here is my response:",
    "delta_text": ":",
    "reasoning_content": None,
    "content": None,
}
# The delta text breaks a potential response special seq; emit the held text
STREAMING_13 = {
    "previous_text": "Here is my thought process: foo Here",
    "current_text": "Here is my thought process: foo Here was",
    "delta_text": " was",
    "reasoning_content": "Here was",
    "content": None,
}
STREAMING_SUBCASES = [
    pytest.param(
        STREAMING_1,
        id="Starting reasoning special sequence",
    ),
    pytest.param(
        STREAMING_2,
        id="Unexpected start reasoning sequence",
    ),
    pytest.param(
        STREAMING_3,
        id="Continuing unexpected start reasoning sequence",
    ),
    pytest.param(
        STREAMING_4,
        id="Only start reasoning sequence and nothing else",
    ),
    pytest.param(
        STREAMING_5,
        id="Reasoning content has started",
    ),
    pytest.param(
        STREAMING_6,
        id="Response special sequence has started",
    ),
    pytest.param(
        STREAMING_7,
        id="Response special sequence reset",
    ),
    pytest.param(
        STREAMING_8,
        id="Response text has started",
    ),
    pytest.param(
        STREAMING_9,
        id="Delta contains everything",
    ),
    pytest.param(
        STREAMING_10,
        id="Delta contains some reasoning and response",
    ),
    pytest.param(
        STREAMING_11,
        id="Delta starts response sequence",
    ),
    pytest.param(
        STREAMING_12,
        id="Delta finishes response sequence",
    ),
    pytest.param(
        STREAMING_13,
        id="Delta breaks potential response sequence",
    ),
]
@pytest.mark.parametrize("param_dict", STREAMING_SUBCASES)
def test_streaming_subcases(param_dict):
    # Get all of the token IDs
    previous_token_ids = tokenizer.encode(
        param_dict["previous_text"]
    ) if param_dict["previous_text"] is not None else []
    current_token_ids = tokenizer.encode(param_dict["current_text"])
    delta_token_ids = tokenizer.encode(param_dict["delta_text"])

    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(
        parser_name)(tokenizer)

    response = parser.extract_reasoning_content_streaming(
        previous_text=param_dict["previous_text"],
        current_text=param_dict["current_text"],
        delta_text=param_dict["delta_text"],
        previous_token_ids=previous_token_ids,
        current_token_ids=current_token_ids,
        delta_token_ids=delta_token_ids,
    )
    # Streaming requires at least one of reasoning content / content;
    # when both are absent, the parser should return None.
    if param_dict["reasoning_content"] is None and param_dict[
            "content"] is None:
        assert response is None
    else:
        assert isinstance(response, DeltaMessage)
        assert param_dict["reasoning_content"] == response.reasoning_content
        assert param_dict["content"] == response.content
@@ -1099,7 +1099,7 @@ class EngineArgs:
        parser.add_argument(
            "--reasoning-parser",
            type=str,
-            choices=["deepseek_r1"],
+            choices=["deepseek_r1", "granite"],
            default=None,
            help=
            "Select the reasoning parser depending on the model that you're "
@@ -2,7 +2,11 @@

from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from .granite_reasoning_parser import GraniteReasoningParser

__all__ = [
-    "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser"
+    "ReasoningParser",
+    "ReasoningParserManager",
+    "DeepSeekR1ReasoningParser",
+    "GraniteReasoningParser",
]
@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0

import re
from collections.abc import Sequence
from typing import Optional, Union

from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
    ReasoningParser, ReasoningParserManager)
from vllm.logger import init_logger

logger = init_logger(__name__)


@ReasoningParserManager.register_module("granite")
class GraniteReasoningParser(ReasoningParser):
    """
    Reasoning parser for IBM Granite.

    IBM Granite models currently use "Here is my thought process:"
    and "Here is my response:" to separate their thinking / response outputs.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # NOTE: There have been some observed occurrences of quantized
        # instances of the current models using "Here's" instead of "Here is",
        # so to be safe, we match on both.
        self.think_start_expr = r"(?:Here's|Here is) my thought process:"
        self.response_start_expr = r"(?:Here's|Here is) my response:"

        self.reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)
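        # For example, matching the regex against
        #   "Here is my thought process: foo Here is my response: bar"
        # captures " foo " and " bar" as the two groups (whitespace around
        # each section is preserved).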
        self.valid_think_starts = [
            "Here's my thought process:", "Here is my thought process:"
        ]
        self.valid_response_starts = [
            "Here's my response:", "Here is my response:"
        ]

        # Substrings to match for sequence boundaries on raw text
        self.seq_boundary_end = ":"
        self.seq_boundary_start = "Here"

        # The longest any thinking / start of response message can be
        self.longest_think_start = max(
            len(think_start) for think_start in self.valid_think_starts)

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract the reasoning content & content sections, respectively.
        If the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed.

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """
        re_match = self.reasoning_regex.findall(model_output)
        if not re_match:
            return None, model_output
        reasoning_content, response_content = re_match[0]
        if not response_content:
            return reasoning_content, None
        return reasoning_content, response_content
    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Extract the reasoning content / content emitted by granite models;
        if the sequence doesn't match what we expect, i.e., the model generates
        something else, all content is considered non-reasoning content.

        NOTE: Granite models do not use a special token to start their
        reasoning and response sections; instead they have token sequences,
        e.g.,

            Here is my thought process: Foo Here is my response: Bar

        This increases the complexity of correctly handling streams, since we
        need to watch for specific sequences and correctly parse them without
        dropping content that is potentially overlapping & spanning multiple
        delta messages.

        Args:
            previous_text (str): Previous text outside of this delta message.
            current_text (str): Previous text + delta text.
            delta_text (str): Text to consider and parse content from.
            previous_token_ids (Sequence[int]): Token IDs of previous_text.
            current_token_ids (Sequence[int]): Token IDs of current_text.
            delta_token_ids (Sequence[int]): Token IDs of delta_text.

        Returns:
            Union[DeltaMessage, None]
                DeltaMessage with either reasoning content or content, or None.
        """
        reasoning_content, resp_seq_len, content = self._get_content_sections(
            current_text)
        # Either we haven't finished the start of the reasoning sequence,
        # or the model is generating something unexpected.
        if not reasoning_content:
            delta_message = self._get_delta_message_with_no_reasoning_bounds(
                current_text, delta_text)
        # We have a start of reasoning message, but have not yet finished
        # the start of response sequence.
        elif not content:
            delta_message = self._get_delta_message_with_no_response_bounds(
                current_text, reasoning_content, delta_text)
        # We've finished both the start of reasoning and start of response seq.
        else:
            # This should never happen since we matched on the response
            assert resp_seq_len is not None
            delta_message = self._get_delta_message_with_both_bounds(
                delta_text, reasoning_content, content, current_text,
                resp_seq_len)
        if not delta_message.content and not delta_message.reasoning_content:
            return None
        return delta_message
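    # Illustrative stream trace (consistent with the streaming unit tests
    # above): for successive deltas "Here is my thought process:", " foo",
    # " Here is my response:", " bar", the calls above yield None (still
    # inside the reasoning start sequence), reasoning_content=" foo", None
    # (response start sequence in flight), and content=" bar", respectively.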
    #### Implementation details of stream parsing for granite models
    def _is_reasoning_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start reasoning seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if any of the possible reasoning start seqs match.
        """
        return any(
            think_start.startswith(text)
            for think_start in self.valid_think_starts)

    def _is_response_start_substr(self, text: str) -> bool:
        """Check if a text matches one of the possible start response seqs.

        Args:
            text (str): Text to check for leading substr.

        Returns:
            bool: True if any of the possible response start seqs match.
        """
        return any(
            response_start.startswith(text)
            for response_start in self.valid_response_starts)

    def _get_delta_message_with_no_reasoning_bounds(
        self,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has not yet completed
        its start of reasoning sequence.

        Args:
            current_text (str): The full previous + delta text.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed content.
        """
        prev_longest_length = len(current_text) - len(delta_text)
        is_substr = self._is_reasoning_start_substr(current_text)
        was_substr = self._is_reasoning_start_substr(
            current_text[:prev_longest_length])

        # Check if we just generated something NOT in the special token seq;
        # if so, add everything that we previously skipped with this delta
        # message and append everything to content in the future.
        if was_substr and not is_substr:
            return DeltaMessage(
                reasoning_content=None,
                content=current_text,
            )
        if is_substr:
            # Might still be in the special token sequence; return nothing
            return DeltaMessage(reasoning_content=None, content=None)
        # Otherwise the sequence has already been broken and we already
        # corrected; just return the delta text as normal content.
        return DeltaMessage(reasoning_content=None, content=delta_text)
    def _get_delta_message_with_no_response_bounds(
        self,
        current_text: str,
        reasoning_content: str,
        delta_text: str,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has reasoning content
        but no (response) content yet. NOTE that we may have overlapping
        tokens with the start of reasoning / start of response sequences on
        either side of the delta text.

        Args:
            current_text (str): The full previous + delta text.
            reasoning_content (str): reasoning content from current_text.
            delta_text (str): Text to consider and parse content from.

        Returns:
            DeltaMessage: Message containing the parsed content.
        """
        # If we have no reasoning content or explicitly end with the start of
        # response sequence, we are in transition to the response; need to be
        # careful here, since the final token (:) will match the reasoning
        # content and fully parse it out; we should not pass the : back.
        ends_with_start_response_seq = any(
            current_text.endswith(response_start)
            for response_start in self.valid_response_starts)
        if reasoning_content is None or ends_with_start_response_seq:
            return DeltaMessage(reasoning_content=None, content=None)

        # Consider previous / current text only within context of the reasoning
        previous_text = reasoning_content[:-len(delta_text)]
        current_text = reasoning_content

        # We need to be careful about adding unfinished response sequences;
        # find the place at which we MIGHT be starting a response sequence
        prev_idx = previous_text.rfind(self.seq_boundary_start)
        delta_idx = delta_text.rfind(self.seq_boundary_start)

        # Check the state of potential start of response substring matches.
        prev_was_substr = self._is_response_start_substr(
            previous_text[prev_idx:]) if prev_idx >= 0 else False
        delta_continues_substr = self._is_response_start_substr(
            current_text[prev_idx:]) if prev_idx >= 0 else False
        delta_new_substr = self._is_response_start_substr(
            delta_text[delta_idx:]) if delta_idx >= 0 else False

        # Delta only contains potential continued response sequence text.
        if delta_continues_substr:
            return DeltaMessage(reasoning_content=None, content=None)

        if not prev_was_substr:
            # Delta may be starting a new response seq but has other text too.
            if delta_new_substr:
                return DeltaMessage(reasoning_content=delta_text[:delta_idx],
                                    content=None)
            # Normal case for most reasoning text (no potential special seqs).
            return DeltaMessage(reasoning_content=delta_text, content=None)
        # The substring that previously seemed to be a potential response
        # seq wasn't one; we need to add the content to the delta message,
        # and also slice off the potential response sequence
        elif delta_new_substr:
            reasoning_content = previous_text[
                prev_idx:] + delta_text[:delta_idx]
            return DeltaMessage(reasoning_content=reasoning_content,
                                content=None)
        # No new substring yet, and we broke our old one; take the whole delta
        return DeltaMessage(
            reasoning_content=previous_text[prev_idx:] + delta_text,
            content=None,
        )
    def _get_delta_message_with_both_bounds(
        self,
        delta_text: str,
        reasoning_content: str,
        response_content: str,
        current_text: str,
        response_seq_len: int,
    ) -> DeltaMessage:
        """Parse the delta message when the current text has both reasoning
        content and normal (response) content.

        Args:
            delta_text (str): Text to consider and parse content from.
            reasoning_content (str): reasoning content from current_text.
            response_content (str): response content from current_text.
            current_text (str): The full previous + delta text.
            response_seq_len (int): Length of the complete response sequence
                used.

        Returns:
            DeltaMessage: Message containing the parsed content.
        """
        # Always have content; take length to the end
        delta_content = delta_text[-len(response_content):]
        reasoning_end_idx = len(delta_text) - (len(response_content) +
                                               response_seq_len)

        if reasoning_end_idx < 0:
            delta_reasoning_content = None
        else:
            # Get the starting offset
            start_reasoning_content_idx = len(
                reasoning_content) + response_seq_len + len(
                    response_content) - 1
            delta_offset = len(current_text) - len(delta_text)
            start_offset = start_reasoning_content_idx - delta_offset
            if start_offset < 0:
                start_offset = 0
            delta_reasoning_content = delta_text[
                start_offset:reasoning_end_idx]

        return DeltaMessage(
            reasoning_content=delta_reasoning_content,
            content=delta_content,
        )
    def _get_content_sections(
        self, current_text: str
    ) -> tuple[Optional[str], Optional[int], Optional[str]]:
        """Parse the text to extract the reasoning content / content
        if we have them.

        Args:
            current_text (str): The full previous + delta text.

        Returns:
            tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3
            containing the reasoning content, the length of the response seq
            (if there is one) and the non-reasoning content.
        """
        current_chunk_start = 0
        start_reasoning_content = None
        parsed_content = False
        delimiter_idxs = [
            idx for idx, char in enumerate(current_text)
            if char == self.seq_boundary_end
        ]

        for current_chunk_end in delimiter_idxs:
            current_chunk = current_text[current_chunk_start:current_chunk_end]
            # Check to see if the start of reasoning seq is complete
            if start_reasoning_content is None:
                for think_start in self.valid_think_starts:
                    if current_chunk == think_start[:-1]:
                        start_reasoning_content = current_chunk_end + 1
                        current_chunk_start = current_chunk_end + 1
                        break

            # Check to see if the start of response seq is complete
            elif not parsed_content:
                for response_start in self.valid_response_starts:
                    if current_chunk[-len(response_start) +
                                     1:] == response_start[:-1]:
                        # Mark end of reasoning and start response content
                        # after the start of response sequence.
                        end_reasoning_content = current_chunk_end - len(
                            response_start)
                        reasoning_content = current_text[
                            start_reasoning_content:end_reasoning_content]
                        response_content = current_text[current_chunk_end + 1:]
                        return reasoning_content, len(
                            response_start), response_content

        if start_reasoning_content and not parsed_content:
            return current_text[start_reasoning_content:], None, None
        return None, None, None
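For orientation, a minimal sketch of exercising the non-streaming path of this parser, mirroring the unit tests above (the tokenizer choice is arbitrary; the tests use `facebook/opt-125m`, and this parser's implementation does not read the `request` argument):

```python
from transformers import AutoTokenizer

from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
parser_cls = ReasoningParserManager.get_reasoning_parser("granite")
parser = parser_cls(tokenizer)

# Split a complete generation into (reasoning, response) sections.
reasoning, content = parser.extract_reasoning_content(
    "Here is my thought process: foo Here is my response: bar",
    request=None,  # typed as ChatCompletionRequest, but unused here
)
assert reasoning == " foo "
assert content == " bar"
```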
@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
        return None
    elif reasoning_backend == "deepseek_r1":
        return DeepSeekReasoner.from_tokenizer(tokenizer)
+    elif reasoning_backend == "granite":
+        logger.warning(
+            "Granite reasoner not yet implemented for structured outputs")
+        return None
    else:
        # Raise a warning for unknown reasoning backend and return None
        # We cannot raise an error here because some reasoning models