diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md
new file mode 100644
index 00000000..e39bbacf
--- /dev/null
+++ b/docs/source/features/reasoning_outputs.md
@@ -0,0 +1,151 @@
+(reasoning-outputs)=
+
+# Reasoning Outputs
+
+vLLM offers support for reasoning models like [DeepSeek R1](https://huggingface.co/deepseek-ai/DeepSeek-R1), which are designed to generate outputs containing both reasoning steps and final conclusions.
+
+Reasoning models return an additional `reasoning_content` field in their outputs, which contains the reasoning steps that led to the final conclusion. This field is not present in the outputs of other models.
+
+## Supported Models
+
+vLLM currently supports the following reasoning models:
+
+- [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) (`deepseek_r1`, which looks for `<think> ... </think>`)
+
+## Quickstart
+
+To use reasoning models, you need to specify the `--enable-reasoning` and `--reasoning-parser` flags when launching the server. The `--reasoning-parser` flag specifies the reasoning parser to use for extracting reasoning content from the model output.
+
+```bash
+vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
+    --enable-reasoning --reasoning-parser deepseek_r1
+```
+
+Next, make a request to the chat completion endpoint; the response will include the reasoning content.
+
+```python
+from openai import OpenAI
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+# Round 1
+messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}]
+response = client.chat.completions.create(model=model, messages=messages)
+
+reasoning_content = response.choices[0].message.reasoning_content
+content = response.choices[0].message.content
+
+print("reasoning_content:", reasoning_content)
+print("content:", content)
+```
+
+The `reasoning_content` field contains the reasoning steps that led to the final conclusion, while the `content` field contains the final conclusion.
+
+## Streaming chat completions
+
+Streaming chat completions are also supported for reasoning models. The `reasoning_content` field is available in the `delta` field in [chat completion response chunks](https://platform.openai.com/docs/api-reference/chat/streaming).
+
+```json
+{
+    "id": "chatcmpl-123",
+    "object": "chat.completion.chunk",
+    "created": 1694268190,
+    "model": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+    "system_fingerprint": "fp_44709d6fcb",
+    "choices": [
+        {
+            "index": 0,
+            "delta": {
+                "role": "assistant",
+                "reasoning_content": "is"
+            },
+            "logprobs": null,
+            "finish_reason": null
+        }
+    ]
+}
+```
+
+Please note that streaming `reasoning_content` is not compatible with the OpenAI Python client library. You can use the `requests` library to make streaming requests.
+
+## How to support a new reasoning model
+
+You can add a new `ReasoningParser` similar to `vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py`.
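+
+For reference, this is what the registered `deepseek_r1` parser does with a complete model output. The snippet is a minimal, illustrative sketch (it is not part of the parser template); it assumes the `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` tokenizer, which already contains the `<think>` tokens, is available locally.
+
+```python
+from transformers import AutoTokenizer
+
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest
+from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager
+
+tokenizer = AutoTokenizer.from_pretrained(
+    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
+parser = ReasoningParserManager.get_reasoning_parser("deepseek_r1")(tokenizer)
+
+reasoning, content = parser.extract_reasoning_content(
+    "<think>9.11 has fewer tenths than 9.8</think>9.8 is greater.",
+    request=ChatCompletionRequest(messages=[], model="test-model"),
+)
+# reasoning -> "9.11 has fewer tenths than 9.8"
+# content   -> "9.8 is greater."
+```
+
+The skeleton below shows the interface a new parser has to implement: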
+ +```python +# import the required packages + +from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( + ReasoningParser, ReasoningParserManager) +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) + +# define a reasoning parser and register it to vllm +# the name list in register_module can be used +# in --reasoning-parser. +@ReasoningParserManager.register_module(["example"]) +class ExampleParser(ReasoningParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting reasoning + from an incomplete response; for use when handling reasoning calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> Tuple[Optional[str], Optional[str]]: + """ + Extract reasoning content from a complete model-generated string. + + Used for non-streaming responses where we have the entire model response + available before sending to the client. + + Parameters: + model_output: str + The model-generated string to extract reasoning content from. + + request: ChatCompletionRequest + The request object that was used to generate the model_output. + + Returns: + Tuple[Optional[str], Optional[str]] + A tuple containing the reasoning content and the content. + """ +``` + +After defining the reasoning parser, you can use it by specifying the `--reasoning-parser` flag when making a request to the chat completion endpoint. + +```bash +vllm serve \ + --enable-reasoning --reasoning-parser example +``` + +## Limitations + +- The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). +- It is not compatible with the [`structured_outputs`](#structured_outputs) and [`tool_calling`](#tool_calling) features. +- The reasoning content is not available for all models. Check the model's documentation to see if it supports reasoning. diff --git a/docs/source/index.md b/docs/source/index.md index 2c302d3f..6957d5dd 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -90,6 +90,7 @@ models/extensions/index features/quantization/index features/lora features/tool_calling +features/reasoning_outputs features/structured_outputs features/automatic_prefix_caching features/disagg_prefill diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py new file mode 100644 index 00000000..83e51a48 --- /dev/null +++ b/examples/online_serving/openai_chat_completion_with_reasoning.py @@ -0,0 +1,53 @@ +""" +An example shows how to generate chat completions from reasoning models +like DeepSeekR1. + +To run this example, you need to start the vLLM server with the reasoning +parser: + +```bash +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ + --enable-reasoning --reasoning-parser deepseek_r1 +``` + +This example demonstrates how to generate chat completions from reasoning models +using the OpenAI Python client library. 
+""" + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +# Round 1 +messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] +response = client.chat.completions.create(model=model, messages=messages) + +reasoning_content = response.choices[0].message.reasoning_content +content = response.choices[0].message.content + +print("reasoning_content:", reasoning_content) +print("content:", content) + +# Round 2 +messages.append({"role": "assistant", "content": content}) +messages.append({ + "role": "user", + "content": "How many Rs are there in the word 'strawberry'?", +}) +response = client.chat.completions.create(model=model, messages=messages) + +reasoning_content = response.choices[0].message.reasoning_content +content = response.choices[0].message.content + +print("reasoning_content:", reasoning_content) +print("content:", content) diff --git a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py new file mode 100644 index 00000000..8c14aac6 --- /dev/null +++ b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py @@ -0,0 +1,90 @@ +""" +An example shows how to generate chat completions from reasoning models +like DeepSeekR1. + +To run this example, you need to start the vLLM server with the reasoning +parser: + +```bash +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ + --enable-reasoning --reasoning-parser deepseek_r1 +``` + +Unlike openai_chat_completion_with_reasoning.py, this example demonstrates the +streaming chat completions feature. + +The streaming chat completions feature allows you to receive chat completions +in real-time as they are generated by the model. This is useful for scenarios +where you want to display chat completions to the user as they are generated +by the model. + +Here we do not use the OpenAI Python client library, because it does not support +`reasoning_content` fields in the response. +""" + +import json + +import requests + +# Modify OpenAI's API key and API base to use vLLM's API server. 
+openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +models = requests.get( + f"{openai_api_base}/models", + headers={ + "Authorization": f"Bearer {openai_api_key}" + }, +).json() +model = models["data"][0]["id"] + +# Streaming chat completions +messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] + +response = requests.post( + f"{openai_api_base}/chat/completions", + headers={"Authorization": f"Bearer {openai_api_key}"}, + json={ + "model": model, + "messages": messages, + "stream": True + }, +) + +print("client: Start streaming chat completions...") +printed_reasoning_content = False +printed_content = False +# Make the streaming request +if response.status_code == 200: + # Process the streaming response + for line in response.iter_lines(): + if line: # Filter out keep-alive new lines + # Decode the line and parse the JSON + decoded_line = line.decode("utf-8") + if decoded_line.startswith("data:"): + data = decoded_line[5:].strip() # Remove "data:" prefix + if data == "[DONE]": # End of stream + print("\nclient: Stream completed.") + break + try: + # Parse the JSON data + chunk = json.loads(data) + reasoning_content = chunk["choices"][0]["delta"].get( + "reasoning_content", "") + content = chunk["choices"][0]["delta"].get("content", "") + + if reasoning_content: + if not printed_reasoning_content: + printed_reasoning_content = True + print("reasoning_content:", end="", flush=True) + print(reasoning_content, end="", flush=True) + elif content: + if not printed_content: + printed_content = True + print("\ncontent:", end="", flush=True) + # Extract and print the content + print(content, end="", flush=True) + except json.JSONDecodeError: + print("Error decoding JSON:", decoded_line) +else: + print(f"Error: {response.status_code} - {response.text}") diff --git a/tests/entrypoints/openai/reasoning_parsers/__init__.py b/tests/entrypoints/openai/reasoning_parsers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py new file mode 100644 index 00000000..4607e4df --- /dev/null +++ b/tests/entrypoints/openai/reasoning_parsers/test_deepseekr1_reasoning_parser.py @@ -0,0 +1,120 @@ +from typing import List + +import pytest +from transformers import AutoTokenizer + +from tests.entrypoints.openai.reasoning_parsers.utils import ( + run_reasoning_extraction) +from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, + ReasoningParserManager) + +parser_name = "deepseek_r1" +start_token = "" +end_token = "" + +SIMPLE_REASONING = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning_content": "This is a reasoning section", + "content": "This is the rest", +} +COMPLETE_REASONING = { + "output": "This is a reasoning section", + "reasoning_content": "This is a reasoning section", + "content": None, +} +NO_REASONING = { + "output": "This is a reasoning section", + "reasoning_content": None, + "content": "This is a reasoning section", +} +MULTIPLE_LINES = { + "output": "This\nThatThis is the rest\nThat", + "reasoning_content": "This\nThat", + "content": "This is the rest\nThat", +} +SHORTEST_REASONING_NO_STREAMING = { + "output": "This is the rest", + "reasoning_content": "", + "content": "This is the rest", +} +SHORTEST_REASONING = { + "output": "This is the rest", + "reasoning_content": None, + "content": "This is the rest", +} + +TEST_CASES = [ + 
pytest.param( + False, + SIMPLE_REASONING, + id="simple_streaming", + ), + pytest.param( + True, + SIMPLE_REASONING, + id="simple_streaming", + ), + pytest.param( + False, + COMPLETE_REASONING, + id="complete_streaming", + ), + pytest.param( + True, + COMPLETE_REASONING, + id="complete_streaming", + ), + pytest.param( + False, + NO_REASONING, + id="no_streaming", + ), + pytest.param( + True, + NO_REASONING, + id="no_streaming", + ), + pytest.param( + False, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + MULTIPLE_LINES, + id="multiple_lines_streaming", + ), + pytest.param( + True, + SHORTEST_REASONING, + id="shortest_streaming", + ), + pytest.param( + False, + SHORTEST_REASONING_NO_STREAMING, + id="shortest_streaming", + ), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_reasoning( + streaming: bool, + param_dict: dict, +): + tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") + tokenizer.add_tokens([start_token, end_token]) + output = tokenizer.tokenize(param_dict["output"]) + # decode everything to tokens + output_tokens: List[str] = [ + tokenizer.convert_tokens_to_string([token]) for token in output + ] + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser( + parser_name)(tokenizer) + + reasoning, content = run_reasoning_extraction(parser, + output_tokens, + streaming=streaming) + + assert reasoning == param_dict["reasoning_content"] + assert content == param_dict["content"] diff --git a/tests/entrypoints/openai/reasoning_parsers/utils.py b/tests/entrypoints/openai/reasoning_parsers/utils.py new file mode 100644 index 00000000..ac73ad50 --- /dev/null +++ b/tests/entrypoints/openai/reasoning_parsers/utils.py @@ -0,0 +1,93 @@ +from typing import List, Optional, Tuple, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.entrypoints.openai.reasoning_parsers import ReasoningParser + + +class StreamingReasoningReconstructor: + + def __init__(self): + self.reasoning_content = None + self.other_content = None + + def append_delta(self, delta: DeltaMessage): + # content and the reasoning content should not be present + # at the same time + assert delta.content is None or delta.reasoning_content is None, ( + "Both content and reasoning content are present in the " + "delta message") + if delta.content is not None: + if self.other_content is None: + self.other_content = delta.content + else: + self.other_content += delta.content + else: + if self.reasoning_content is None: + self.reasoning_content = delta.reasoning_content + else: + self.reasoning_content += delta.reasoning_content + + +def run_reasoning_extraction( + reasoning_parser: ReasoningParser, + model_output: List[str], + request: Union[ChatCompletionRequest, None] = None, + streaming: bool = False, +) -> Tuple[Optional[str], Optional[str]]: + if streaming: + reconstructor = run_reasoning_extraction_streaming( + reasoning_parser, + model_output, + request, + ) + return ( + reconstructor.reasoning_content, + reconstructor.other_content or None, + ) + else: + reasoning, content = run_reasoning_extraction_nonstreaming( + reasoning_parser, model_output, request) + return reasoning, content + + +def run_reasoning_extraction_nonstreaming( + reasoning_parser: ReasoningParser, + model_output: List[str], + request: Union[ChatCompletionRequest, None] = None, +) -> Tuple[Optional[str], Optional[str]]: + request = request or ChatCompletionRequest(messages=[], model="test-model") + return 
reasoning_parser.extract_reasoning_content( + model_output=''.join(model_output), request=request) + + +def run_reasoning_extraction_streaming( + reasoning_parser: ReasoningParser, + model_deltas: List[str], + request: Union[ChatCompletionRequest, None] = None, +) -> StreamingReasoningReconstructor: + request = request or ChatCompletionRequest(messages=[], model="test-model") + reconstructor = StreamingReasoningReconstructor() + previous_text = "" + previous_tokens: List[int] = [] + for delta in model_deltas: + token_delta = [ + reasoning_parser.vocab.get(token) + for token in reasoning_parser.model_tokenizer.tokenize(delta) + if token in reasoning_parser.vocab + ] + current_text = previous_text + delta + current_tokens = previous_tokens + token_delta + delta_message = reasoning_parser.extract_reasoning_content_streaming( + previous_text, + current_text, + delta, + previous_tokens, + current_tokens, + token_delta, + ) + if delta_message is not None: + reconstructor.append_delta(delta_message) + previous_text = current_text + previous_tokens = current_tokens + return reconstructor diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index e49562ad..01bcd78a 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -116,6 +116,35 @@ def test_enable_auto_choice_passes_with_tool_call_parser(serve_parser): validate_parsed_serve_args(args) +def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser): + """Ensure validation fails if reasoning is enabled with auto tool choice""" + args = serve_parser.parse_args(args=[ + "--enable-auto-tool-choice", + "--enable-reasoning", + ]) + with pytest.raises(TypeError): + validate_parsed_serve_args(args) + + +def test_enable_reasoning_passes_with_reasoning_parser(serve_parser): + """Ensure validation passes if reasoning is enabled + with a reasoning parser""" + args = serve_parser.parse_args(args=[ + "--enable-reasoning", + "--reasoning-parser", + "deepseek_r1", + ]) + validate_parsed_serve_args(args) + + +def test_enable_reasoning_fails_without_reasoning_parser(serve_parser): + """Ensure validation fails if reasoning is enabled + without a reasoning parser""" + args = serve_parser.parse_args(args=["--enable-reasoning"]) + with pytest.raises(TypeError): + validate_parsed_serve_args(args) + + def test_chat_template_validation_for_happy_paths(serve_parser): """Ensure validation passes if the chat template exists""" args = serve_parser.parse_args( diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 077bc993..9e5cf4ba 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -61,6 +61,7 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager # yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion @@ -771,6 +772,8 @@ async def init_app_state( return_tokens_as_token_ids=args.return_tokens_as_token_ids, enable_auto_tools=args.enable_auto_tool_choice, tool_parser=args.tool_call_parser, + enable_reasoning=args.enable_reasoning, + reasoning_parser=args.reasoning_parser, enable_prompt_tokens_details=args.enable_prompt_tokens_details, ) if model_config.runner_type == "generate" else None state.openai_serving_completion = 
OpenAIServingCompletion( @@ -844,6 +847,13 @@ async def run_server(args, **uvicorn_kwargs) -> None: raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " f"(chose from {{ {','.join(valid_tool_parses)} }})") + valid_reasoning_parses = ReasoningParserManager.reasoning_parsers.keys() + if args.enable_reasoning \ + and args.reasoning_parser not in valid_reasoning_parses: + raise KeyError( + f"invalid reasoning parser: {args.reasoning_parser} " + f"(chose from {{ {','.join(valid_reasoning_parses)} }})") + # workaround to make sure that we bind the port before the engine is set up. # This avoids race conditions with ray. # see https://github.com/vllm-project/vllm/issues/8204 diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4df75a66..9cfe07c6 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -12,6 +12,7 @@ from typing import List, Optional, Sequence, Union, get_args from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption, validate_chat_template) +from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager from vllm.entrypoints.openai.serving_models import (LoRAModulePath, PromptAdapterPath) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -208,6 +209,23 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=False, help="Enable auto tool choice for supported models. Use " "``--tool-call-parser`` to specify which parser to use.") + parser.add_argument( + "--enable-reasoning", + action="store_true", + default=False, + help="Whether to enable reasoning_content for the model. " + "If enabled, the model will be able to generate reasoning content.") + + valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys() + parser.add_argument( + "--reasoning-parser", + type=str, + metavar="{" + ",".join(valid_reasoning_parsers) + "}", + default=None, + help= + "Select the reasoning parser depending on the model that you're using." + " This is used to parse the reasoning content into OpenAI API " + "format. Required for ``--enable-reasoning``.") valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( @@ -267,6 +285,18 @@ def validate_parsed_serve_args(args: argparse.Namespace): raise TypeError("Error: --enable-auto-tool-choice requires " "--tool-call-parser") + # Enable reasoning needs a reasoning parser to be valid + if args.enable_reasoning and not args.reasoning_parser: + raise TypeError("Error: --enable-reasoning requires " + "--reasoning-parser") + + # Ref https://api-docs.deepseek.com/guides/reasoning_model + # tool call and reasoning cannot be enabled at the same time. 
+ if args.enable_auto_tool_choice and args.enable_reasoning: + raise TypeError( + "Error: --enable-auto-tool-choice and " + "--enable-reasoning cannot be enabled at the same time") + def create_parser_for_docs() -> FlexibleArgumentParser: parser_for_docs = FlexibleArgumentParser( diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f89c3f42..2bc136cc 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -1202,6 +1202,7 @@ class ExtractedToolCallInformation(BaseModel): class ChatMessage(OpenAIBaseModel): role: str + reasoning_content: Optional[str] = None content: Optional[str] = None tool_calls: List[ToolCall] = Field(default_factory=list) @@ -1243,6 +1244,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: List[DeltaToolCall] = Field(default_factory=list) diff --git a/vllm/entrypoints/openai/reasoning_parsers/__init__.py b/vllm/entrypoints/openai/reasoning_parsers/__init__.py new file mode 100644 index 00000000..a21bff52 --- /dev/null +++ b/vllm/entrypoints/openai/reasoning_parsers/__init__.py @@ -0,0 +1,6 @@ +from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager +from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser + +__all__ = [ + "ReasoningParser", "ReasoningParserManager", "DeepSeekR1ReasoningParser" +] diff --git a/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py b/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py new file mode 100644 index 00000000..e5d10ee0 --- /dev/null +++ b/vllm/entrypoints/openai/reasoning_parsers/abs_reasoning_parsers.py @@ -0,0 +1,158 @@ +import os +from functools import cached_property +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import import_from_path, is_list_of + +logger = init_logger(__name__) + + +class ReasoningParser: + """ + Abstract reasoning parser class that should not be used directly. + Provided and methods should be used in derived classes. + + It is used to extract reasoning content from the model output. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.model_tokenizer = tokenizer + + @cached_property + def vocab(self) -> Dict[str, int]: + # NOTE: Only PreTrainedTokenizerFast is guaranteed to have .vocab + # whereas all tokenizers have .get_vocab() + return self.model_tokenizer.get_vocab() + + def extract_reasoning_content( + self, model_output: str, request: ChatCompletionRequest + ) -> Tuple[Optional[str], Optional[str]]: + """ + Extract reasoning content from a complete model-generated string. + + Used for non-streaming responses where we have the entire model response + available before sending to the client. + + Parameters: + model_output: str + The model-generated string to extract reasoning content from. + + request: ChatCompletionRequest + The request object that was used to generate the model_output. + + Returns: + Tuple[Optional[str], Optional[str]] + A tuple containing the reasoning content and the content. 
+ """ + + raise NotImplementedError( + "AbstractReasoningParser.extract_reasoning_calls " + "has not been implemented!") + + def extract_reasoning_content_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting reasoning + from an incomplete response; for use when handling reasoning calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractReasoningParser.extract_reasoning_content_streaming " + "has not been implemented!") + + +class ReasoningParserManager: + reasoning_parsers: Dict[str, Type] = {} + + @classmethod + def get_reasoning_parser(cls, name) -> Type: + """ + Get reasoning parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. + """ + if name in cls.reasoning_parsers: + return cls.reasoning_parsers[name] + + raise KeyError(f"reasoning helper: '{name}' not found in " + "reasoning_parsers") + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ReasoningParser): + raise TypeError("module must be subclass of ReasoningParser, " + f"but got {type(module)}") + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.reasoning_parsers: + existed_module = cls.reasoning_parsers[name] + raise KeyError(f"{name} is already registered " + f"at {existed_module.__module__}") + cls.reasoning_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f"force must be a boolean, but got {type(force)}") + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + "name must be None, an instance of str, or a sequence of str, " + f"but got {type(name)}") + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_reasoning_parser(cls, plugin_path: str) -> None: + """ + Import a user-defined reasoning parser by the path + of the reasoning parser define file. 
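+
+        Example (illustrative):
+            ReasoningParserManager.import_reasoning_parser(
+                "/path/to/my_reasoning_parser.py")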
+        """
+        module_name = os.path.splitext(os.path.basename(plugin_path))[0]
+
+        try:
+            import_from_path(module_name, plugin_path)
+        except Exception:
+            logger.exception("Failed to load module '%s' from %s.",
+                             module_name, plugin_path)
+            return
diff --git a/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
new file mode 100644
index 00000000..a440ddc8
--- /dev/null
+++ b/vllm/entrypoints/openai/reasoning_parsers/deepseek_r1_reasoning_parser.py
@@ -0,0 +1,133 @@
+import re
+from typing import Optional, Sequence, Tuple, Union
+
+from transformers import PreTrainedTokenizerBase
+
+from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
+                                              DeltaMessage)
+from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import (
+    ReasoningParser, ReasoningParserManager)
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@ReasoningParserManager.register_module("deepseek_r1")
+class DeepSeekR1ReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for the DeepSeek R1 model.
+
+    The DeepSeek R1 model uses <think>...</think> tokens to denote reasoning
+    text. This parser extracts the reasoning content from the model output.
+    """
+
+    def __init__(self, tokenizer: PreTrainedTokenizerBase):
+        super().__init__(tokenizer)
+        self.think_start_token = "<think>"
+        self.think_end_token = "</think>"
+
+        self.reasoning_regex = re.compile(
+            rf"{self.think_start_token}(.*?){self.think_end_token}", re.DOTALL)
+
+        if not self.model_tokenizer:
+            raise ValueError(
+                "The model tokenizer must be passed to the ReasoningParser "
+                "constructor during construction.")
+
+        self.think_start_token_id = self.vocab.get(self.think_start_token)
+        self.think_end_token_id = self.vocab.get(self.think_end_token)
+        if (self.think_start_token_id is None
+                or self.think_end_token_id is None):
+            raise RuntimeError(
+                "DeepSeek R1 reasoning parser could not locate think start/end "
+                "tokens in the tokenizer!")
+
+    def extract_reasoning_content_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> Union[DeltaMessage, None]:
+        """
+        Extract reasoning content from a delta message.
+        Handles streaming output where previous + delta = current.
+        Uses token IDs for faster processing.
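+
+        For example (illustrative, assuming the tokenizer emits the think
+        markers as standalone tokens), streaming the deltas
+        ["<think>", "abc", "</think>xyz"] yields a reasoning_content delta
+        of "abc" followed by a content delta of "xyz".
+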
+        For text <think>abc</think>xyz:
+        - 'abc' goes to reasoning_content
+        - 'xyz' goes to content
+        """
+        # Skip single special tokens
+        if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
+                self.think_start_token_id, self.think_end_token_id
+        ]):
+            return None
+
+        if self.think_start_token_id in previous_token_ids:
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in previous, </think> in delta,
+                # extract reasoning content
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[:end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            elif self.think_end_token_id in previous_token_ids:
+                # <think> in previous, </think> in previous,
+                # content continues
+                return DeltaMessage(content=delta_text)
+            else:
+                # <think> in previous, no </think> in previous or delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        elif self.think_start_token_id in delta_token_ids:
+            logger.info(delta_text)
+            if self.think_end_token_id in delta_token_ids:
+                # <think> in delta, </think> in delta, extract reasoning content
+                start_index = delta_text.find(self.think_start_token)
+                end_index = delta_text.find(self.think_end_token)
+                reasoning_content = delta_text[start_index +
+                                               len(self.think_start_token
+                                                   ):end_index]
+                content = delta_text[end_index + len(self.think_end_token):]
+                return DeltaMessage(reasoning_content=reasoning_content,
+                                    content=content if content else None)
+            else:
+                # <think> in delta, no </think> in delta,
+                # reasoning content continues
+                return DeltaMessage(reasoning_content=delta_text)
+        else:
+            # No <think> in previous or delta; treat the delta as content.
+            return DeltaMessage(content=delta_text)
+
+    def extract_reasoning_content(
+            self, model_output: str, request: ChatCompletionRequest
+    ) -> Tuple[Optional[str], Optional[str]]:
+
+        # Check if the model output contains the <think> tokens.
+        if (self.think_start_token not in model_output
+                or self.think_end_token not in model_output):
+            return None, model_output
+        else:
+            # Use a regex to find the reasoning content
+            reasoning_content = self.reasoning_regex.findall(model_output)[0]
+
+            # Remove the reasoning content from the model output
+            # Although deepseek's <think> token is always at the
+            # beginning of the line, we cannot guarantee that the
+            # other models will follow this convention.
+            # Therefore, we need the :start_index slice to keep any
+            # text that appears before <think>.
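+            # For example (illustrative):
+            # "foo<think>bar</think>baz" -> reasoning "bar", content "foobaz".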
+ start_index = model_output.find(self.think_start_token) + if start_index != -1: + end_index = start_index + len( + f"{self.think_start_token}{reasoning_content}{self.think_end_token}" + ) + model_output = model_output[:start_index] + \ + model_output[end_index:] + + if len(model_output) == 0: + return reasoning_content, None + + return reasoning_content, model_output diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 89a119ac..dc97f0eb 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -21,6 +21,8 @@ from vllm.entrypoints.openai.protocol import ( ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo, RequestResponseMetadata, ToolCall, UsageInfo) +from vllm.entrypoints.openai.reasoning_parsers import (ReasoningParser, + ReasoningParserManager) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager @@ -47,6 +49,8 @@ class OpenAIServingChat(OpenAIServing): chat_template: Optional[str], chat_template_content_format: ChatTemplateContentFormatOption, return_tokens_as_token_ids: bool = False, + enable_reasoning: bool = False, + reasoning_parser: Optional[str] = None, enable_auto_tools: bool = False, tool_parser: Optional[str] = None, enable_prompt_tokens_details: bool = False, @@ -69,6 +73,18 @@ class OpenAIServingChat(OpenAIServing): " the parallel_tool_calls client option is preset for " "compatibility reasons, it will be ignored.") + self.enable_reasoning: bool = enable_reasoning + self.reasoning_parser: Optional[Callable[[AnyTokenizer], + ReasoningParser]] = None + if self.enable_reasoning: + try: + self.reasoning_parser = ( + ReasoningParserManager.get_reasoning_parser( + reasoning_parser)) + except Exception as e: + raise TypeError("Error: --enable-reasoning requires " + f"reasoning_parser:'{reasoning_parser}' " + "which has not been registered") from e self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: try: @@ -285,14 +301,35 @@ class OpenAIServingChat(OpenAIServing): not tool_choice_function_name and self._should_stream_with_auto_tool_parsing(request)) + should_stream_with_reasoning_parsing = ( + self._should_stream_with_reasoning_parsing(request)) + all_previous_token_ids: Optional[List[List[int]]] - if tool_choice_auto: + + # Only one of these will be used, thus previous_texts and + # all_previous_token_ids will not be used twice in the same iteration. + if tool_choice_auto or should_stream_with_reasoning_parsing: # These are only required in "auto" tool choice case previous_texts = [""] * num_choices all_previous_token_ids = [[]] * num_choices else: previous_texts, all_previous_token_ids = None, None + try: + # There is no need to check if the reasoning_parser is None + # because the should_stream_with_reasoning_parsing check + # already ensures that the reasoning_parser is not None. + # but the pre-commit hook requires it. 
+ if should_stream_with_reasoning_parsing and \ + self.reasoning_parser is not None: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + data = self.create_streaming_error_response(str(e)) + yield f"data: {data}\n\n" + yield "data: [DONE]\n\n" + return + # Prepare the tool parser if it's needed try: if tool_choice_auto and self.tool_parser: @@ -456,6 +493,32 @@ class OpenAIServingChat(OpenAIServing): # update the previous values for the next iteration previous_texts[i] = current_text all_previous_token_ids[i] = current_token_ids + # reasoning_content cannot be enabled with tool_choice. + # If it is, the tool_choice will be used instead. + elif self.enable_reasoning: + # handle reasoning_content delta + assert reasoning_parser is not None + assert previous_texts is not None + assert all_previous_token_ids is not None + previous_text = previous_texts[i] + previous_token_ids = all_previous_token_ids[i] + current_text = previous_text + delta_text + current_token_ids = previous_token_ids + list( + output.token_ids) + + delta_message = (reasoning_parser. + extract_reasoning_content_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + output.token_ids, + )) + + # update the previous values for the next iteration + previous_texts[i] = current_text + all_previous_token_ids[i] = current_token_ids # handle streaming just a content delta else: @@ -642,17 +705,38 @@ class OpenAIServingChat(OpenAIServing): else: logprobs = None + should_stream_with_reasoning_parsing = ( + self._should_stream_with_reasoning_parsing(request)) + # In the OpenAI API the finish_reason is "tools_called" # if the tool choice is auto and the model produced a tool # call. The same is not true for named function calls auto_tools_called = False + if should_stream_with_reasoning_parsing and \ + self.reasoning_parser is not None: + try: + reasoning_parser = self.reasoning_parser(tokenizer) + except RuntimeError as e: + logger.exception("Error in reasoning parser creation.") + return self.create_error_response(str(e)) + + reasoning_content, content = ( + reasoning_parser.extract_reasoning_content( + output.text, request=request)) + + if reasoning_content: + message = ChatMessage(role=role, + content=content, + reasoning_content=reasoning_content) + else: + message = ChatMessage(role=role, content=output.text) + # if auto tools are not enabled, and a named tool choice using # outlines is not being used - if (not self.enable_auto_tools - or not self.tool_parser) and not isinstance( - request.tool_choice, - ChatCompletionNamedToolChoiceParam): + elif (not self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, ChatCompletionNamedToolChoiceParam): message = ChatMessage(role=role, content=output.text) # if the request uses tools and specified a tool choice @@ -835,6 +919,17 @@ class OpenAIServingChat(OpenAIServing): return (request.tools and self.tool_parser and self.enable_auto_tools and request.tool_choice in ['auto', None]) + def _should_stream_with_reasoning_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the + reasoning parser that was configured. + + We only want to do this IF reasoning is enabled and a reasoning + parser is configured. 
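+
+        Note that, unlike the tool-parsing check above, this does not depend
+        on the incoming request, only on the server configuration.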
+ """ + return self.enable_reasoning and self.reasoning_parser is not None + def _should_check_for_unstreamed_tool_arg_tokens( self, delta_message: Optional[DeltaMessage], diff --git a/vllm/scripts.py b/vllm/scripts.py index 42e1c639..8101e6b3 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -167,6 +167,7 @@ def main(): "Must be a YAML with the following options:" "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#cli-reference" ) + serve_parser = make_arg_parser(serve_parser) serve_parser.set_defaults(dispatch_function=serve)
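To tie the streaming pieces above together, here is a standalone sketch of how the per-delta extraction behaves. It is illustrative only and not part of this change; it assumes the `deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B` tokenizer (which contains the `<think>` tokens) and is similar in spirit to the helper in `tests/entrypoints/openai/reasoning_parsers/utils.py`.

```python
from typing import List

from transformers import AutoTokenizer

from vllm.entrypoints.openai.reasoning_parsers import ReasoningParserManager

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
parser = ReasoningParserManager.get_reasoning_parser("deepseek_r1")(tokenizer)

# Feed the output to the parser one delta at a time, the same way the
# streaming path in serving_chat.py does above.
deltas = ["<think>", "Compare", " the tenths digit.", "</think>",
          "9.8 is greater."]
previous_text = ""
previous_token_ids: List[int] = []
for delta_text in deltas:
    delta_token_ids = tokenizer.encode(delta_text, add_special_tokens=False)
    current_text = previous_text + delta_text
    current_token_ids = previous_token_ids + delta_token_ids
    delta_message = parser.extract_reasoning_content_streaming(
        previous_text, current_text, delta_text,
        previous_token_ids, current_token_ids, delta_token_ids)
    if delta_message is not None:
        # Reasoning deltas arrive first, then content deltas; the bare
        # think tokens themselves are swallowed by the parser.
        print(delta_message.reasoning_content, delta_message.content)
    previous_text = current_text
    previous_token_ids = current_token_ids
```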