From 87d41c849d2cde9279fb08a3a0d97123e3d8fe2f Mon Sep 17 00:00:00 2001 From: Breno Faria Date: Thu, 30 May 2024 11:52:14 +0200 Subject: [PATCH] [BUGFIX] [FRONTEND] Correct chat logprobs (#5029) Co-authored-by: Breno Faria --- tests/async_engine/test_openapi_server_ray.py | 6 +- tests/entrypoints/test_openai_server.py | 211 +++++++++++++++--- vllm/entrypoints/openai/protocol.py | 50 ++++- vllm/entrypoints/openai/serving_chat.py | 68 +++++- vllm/entrypoints/openai/serving_completion.py | 74 +++++- vllm/entrypoints/openai/serving_engine.py | 52 +---- 6 files changed, 362 insertions(+), 99 deletions(-) diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 7a8d4b39..4c362a05 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI): chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 2463ccde..97213703 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=None, + ) + choice = completion.choices[0] + assert choice.logprobs is None + + @pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras @@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, choice = completion.choices[0] assert choice.logprobs is not None assert choice.logprobs.token_logprobs is not None - assert choice.logprobs.top_logprobs is None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=5, + ) + choice = completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.token_logprobs is not None + assert choice.logprobs.top_logprobs is not None + assert len(choice.logprobs.top_logprobs[0]) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_too_many_completion_logprobs(server, client: 
openai.AsyncOpenAI, + model_name: str): + + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + ) + ... + with pytest.raises( + (openai.BadRequestError, openai.APIError)): # test using token IDs + stream = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + logprobs=6, + stream=True, + ) + async for chunk in stream: + ... + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 @pytest.mark.asyncio @@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, chat_completion.choices) == 1 assert chat_completion.choices[0].message is not None assert chat_completion.choices[0].logprobs is not None - assert chat_completion.choices[0].logprobs.top_logprobs is not None - assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + assert chat_completion.choices[0].logprobs.content[ + 0].top_logprobs is not None + assert len( + chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5 message = chat_completion.choices[0].message assert message.content is not None and len(message.content) >= 10 assert message.role == "assistant" @@ -252,9 +339,13 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, @pytest.mark.asyncio -@pytest.mark.parametrize("model_name", [MODEL_NAME]) -async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, - model_name: str): +@pytest.mark.parametrize( + # first test base model, then test loras + "model_name", + [MODEL_NAME, "zephyr-lora", "zephyr-lora2"], +) +async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): messages = [{ "role": "system", "content": "you are a helpful assistant" @@ -263,13 +354,92 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, "content": "what is 1+1?" }] - # Default max_logprobs is 5, so this should raise an error + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=False) + + choice = chat_completion.choices[0] + assert choice.logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + # just test 1 lora hereafter + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=0) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME, "zephyr-lora"], +) +async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" 
+ }] + + chat_completion = await client.chat.completions.create(model=model_name, + messages=messages, + max_tokens=5, + temperature=0.0, + logprobs=True, + top_logprobs=5) + + choice = chat_completion.choices[0] + assert choice.logprobs is not None + assert choice.logprobs.content is not None + assert len(choice.logprobs.content[0].top_logprobs) <= 6 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI, + model_name: str): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "what is 1+1?" + }] + + # Default max_logprobs is 20, so this should raise an error with pytest.raises((openai.BadRequestError, openai.APIError)): stream = await client.chat.completions.create(model=model_name, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=21, stream=True) async for chunk in stream: ... @@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, messages=messages, max_tokens=10, logprobs=True, - top_logprobs=10, + top_logprobs=30, stream=False) - with pytest.raises((openai.BadRequestError, openai.APIError)): - stream = await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=True) - async for chunk in stream: - ... - - with pytest.raises(openai.BadRequestError): - await client.completions.create(model=model_name, - prompt="Test", - max_tokens=10, - logprobs=10, - stream=False) - # the server should still work afterwards chat_completion = await client.chat.completions.create(model=model_name, messages=messages, @@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, top_logprobs=5, extra_body=dict(guided_choice=TEST_CHOICE, guided_decoding_backend=guided_decoding_backend)) - top_logprobs = chat_completion.choices[0].logprobs.top_logprobs + top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs # -9999.0 is the minimum logprob returned by OpenAI assert all( - isinstance(logprob, float) and logprob >= -9999.0 - for token_dict in top_logprobs - for token, logprob in token_dict.items()) + isinstance(token.logprob, float) and token.logprob >= -9999.0 + for token in top_logprobs) @pytest.mark.asyncio diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index e6eae689..e380212a 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -250,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "top_logprobs" in data and data["top_logprobs"] is not None: + if "logprobs" not in data or data["logprobs"] is False: + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." 
+ ) + elif not 0 <= data["top_logprobs"] <= 20: + raise ValueError( + "`top_logprobs` must be a value in the interval [0, 20].") + return data + class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -396,6 +409,15 @@ class CompletionRequest(OpenAIBaseModel): "('guided_json', 'guided_regex' or 'guided_choice').") return data + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if "logprobs" in data and data[ + "logprobs"] is not None and not 0 <= data["logprobs"] <= 5: + raise ValueError(("if passed, `logprobs` must be a value", + " in the interval [0, 5].")) + return data + class EmbeddingRequest(BaseModel): # Ordered by official OpenAI API documentation @@ -415,7 +437,7 @@ class EmbeddingRequest(BaseModel): return PoolingParams(additional_data=self.additional_data) -class LogProbs(OpenAIBaseModel): +class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) @@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel): class CompletionResponseChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None + logprobs: Optional[CompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = Field( default=None, @@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel): class CompletionResponseStreamChoice(OpenAIBaseModel): index: int text: str - logprobs: Optional[LogProbs] = None + logprobs: Optional[CompletionLogProbs] = None finish_reason: Optional[str] = None stop_reason: Optional[Union[int, str]] = Field( default=None, @@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel): content: str +class ChatCompletionLogProb(OpenAIBaseModel): + token: str + logprob: float = -9999.0 + bytes: Optional[List[int]] = None + + +class ChatCompletionLogProbsContent(ChatCompletionLogProb): + top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list) + + +class ChatCompletionLogProbs(OpenAIBaseModel): + content: Optional[List[ChatCompletionLogProbsContent]] = None + + class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[str] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None stop_reason: Optional[Union[int, str]] = None @@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel): class ChatCompletionResponseStreamChoice(OpenAIBaseModel): index: int delta: DeltaMessage - logprobs: Optional[LogProbs] = None - finish_reason: Optional[str] = None + logprobs: Optional[ChatCompletionLogProbs] = None + finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None stop_reason: Optional[Union[int, str]] = None diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 8cb50e33..cc5b896e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,8 +1,10 @@ import codecs import time from dataclasses import dataclass -from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional, - TypedDict, Union, cast, final) +from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List, + Optional) +from typing import Sequence as GenericSequence +from typing import TypedDict, Union, cast, final from fastapi import Request from openai.types.chat import 
ChatCompletionContentPartTextParam @@ -10,8 +12,9 @@ from openai.types.chat import ChatCompletionContentPartTextParam from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( - ChatCompletionContentPartParam, ChatCompletionMessageParam, - ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionContentPartParam, ChatCompletionLogProb, + ChatCompletionLogProbs, ChatCompletionLogProbsContent, + ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) @@ -21,6 +24,7 @@ from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput +from vllm.sequence import Logprob from vllm.utils import random_uuid logger = init_logger(__name__) @@ -283,11 +287,10 @@ class OpenAIServingChat(OpenAIServing): previous_num_tokens[i]:] if output.logprobs else None if request.logprobs: - logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.top_logprobs, - initial_text_offset=len(previous_texts[i]), ) else: logprobs = None @@ -370,7 +373,7 @@ class OpenAIServingChat(OpenAIServing): top_logprobs = output.logprobs if request.logprobs: - logprobs = self._create_logprobs( + logprobs = self._create_chat_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.top_logprobs, @@ -383,8 +386,7 @@ class OpenAIServingChat(OpenAIServing): message=ChatMessage(role=role, content=output.text), logprobs=logprobs, finish_reason=output.finish_reason, - stop_reason=output.stop_reason, - ) + stop_reason=output.stop_reason) choices.append(choice_data) if request.echo: @@ -414,3 +416,51 @@ class OpenAIServingChat(OpenAIServing): ) return response + + def _get_top_logprobs( + self, logprobs: Dict[int, Logprob], + top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]: + return [ + ChatCompletionLogProb( + token=self._get_decoded_token(p[1], p[0]), + logprob=max(p[1].logprob, -9999.0), + bytes=list( + self._get_decoded_token(p[1], + p[0]).encode("utf-8", + errors="replace"))) + for i, p in enumerate(logprobs.items()) + if top_logprobs and i < top_logprobs + ] + + def _create_chat_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: Optional[int] = None, + ) -> ChatCompletionLogProbs: + """Create OpenAI-style logprobs.""" + + logprobs_content = [] + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + logprobs_content.append( + ChatCompletionLogProbsContent( + token=self.tokenizer.decode(token_id), + bytes=list( + self.tokenizer.decode(token_id).encode( + "utf-8", errors="replace")))) + else: + logprobs_content.append( + ChatCompletionLogProbsContent( + token=step_top_logprobs[token_id].decoded_token, + logprob=max(step_top_logprobs[token_id].logprob, + -9999.0), + bytes=list( + step_top_logprobs[token_id].decoded_token.encode( + "utf-8", errors="replace")), + top_logprobs=self._get_top_logprobs( + step_top_logprobs, num_output_top_logprobs))) + + return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/serving_completion.py 
b/vllm/entrypoints/openai/serving_completion.py index d1812c8f..2fb122ed 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,23 +1,29 @@ import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, - Optional, Tuple) + Optional) +from typing import Sequence as GenericSequence +from typing import Tuple from fastapi import Request from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.protocol import (CompletionRequest, +# yapf: disable +from vllm.entrypoints.openai.protocol import (CompletionLogProbs, + CompletionRequest, CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - LogProbs, UsageInfo) + UsageInfo) +# yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput +from vllm.sequence import Logprob from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -25,7 +31,7 @@ logger = init_logger(__name__) TypeTokenIDs = List[int] TypeTopLogProbs = List[Optional[Dict[int, float]]] TypeCreateLogProbsFn = Callable[ - [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] def parse_prompt_format(prompt) -> Tuple[bool, list]: @@ -235,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing): i]:] if output.logprobs else None if request.logprobs is not None: - logprobs = self._create_logprobs( + logprobs = self._create_completion_logprobs( token_ids=delta_token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -317,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing): assert top_logprobs is not None, ( "top_logprobs must be provided when logprobs " "is requested") - logprobs = self._create_logprobs( + logprobs = self._create_completion_logprobs( token_ids=token_ids, top_logprobs=top_logprobs, num_output_top_logprobs=request.logprobs, @@ -351,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing): choices=choices, usage=usage, ) + + def _create_completion_logprobs( + self, + token_ids: GenericSequence[int], + top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], + num_output_top_logprobs: int, + initial_text_offset: int = 0, + ) -> CompletionLogProbs: + """Create logprobs for OpenAI Completion API.""" + out_text_offset: List[int] = [] + out_token_logprobs: List[Optional[float]] = [] + out_tokens: List[str] = [] + out_top_logprobs: List[Optional[Dict[str, float]]] = [] + + last_token_len = 0 + + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is None: + token = self.tokenizer.decode(token_id) + out_tokens.append(token) + out_token_logprobs.append(None) + out_top_logprobs.append(None) + else: + token = self._get_decoded_token(step_top_logprobs[token_id], + token_id) + token_logprob = max(step_top_logprobs[token_id].logprob, + -9999.0) + out_tokens.append(token) + out_token_logprobs.append(token_logprob) + + # makes sure to add the top num_output_top_logprobs + 1 + # logprobs, as defined in the openai API + # (cf. 
https://github.com/openai/openai-openapi/blob/ + # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153) + out_top_logprobs.append({ + # Convert float("-inf") to the + # JSON-serializable float that OpenAI uses + self._get_decoded_token(top_lp[1], top_lp[0]): + max(top_lp[1].logprob, -9999.0) + for i, top_lp in enumerate(step_top_logprobs.items()) + if num_output_top_logprobs >= i + }) + + if len(out_text_offset) == 0: + out_text_offset.append(initial_text_offset) + else: + out_text_offset.append(out_text_offset[-1] + last_token_len) + last_token_len = len(token) + + return CompletionLogProbs( + text_offset=out_text_offset, + token_logprobs=out_token_logprobs, + tokens=out_tokens, + top_logprobs=out_top_logprobs, + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 708b0dad..066acdf1 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -11,7 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, EmbeddingRequest, ErrorResponse, - LogProbs, ModelCard, ModelList, + ModelCard, ModelList, ModelPermission) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -75,51 +75,6 @@ class OpenAIServing: model_cards.extend(lora_cards) return ModelList(data=model_cards) - def _create_logprobs( - self, - token_ids: List[int], - top_logprobs: List[Optional[Dict[int, Logprob]]], - num_output_top_logprobs: Optional[int] = None, - initial_text_offset: int = 0, - ) -> LogProbs: - """Create OpenAI-style logprobs.""" - logprobs = LogProbs() - last_token_len = 0 - if num_output_top_logprobs: - logprobs.top_logprobs = [] - - for i, token_id in enumerate(token_ids): - step_top_logprobs = top_logprobs[i] - if step_top_logprobs is None: - token = self.tokenizer.decode(token_id) - logprobs.tokens.append(token) - logprobs.token_logprobs.append(None) - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append(None) - else: - token_logprob = step_top_logprobs[token_id].logprob - token = step_top_logprobs[token_id].decoded_token - logprobs.tokens.append(token) - token_logprob = max(token_logprob, -9999.0) - logprobs.token_logprobs.append(token_logprob) - - if num_output_top_logprobs: - assert logprobs.top_logprobs is not None - logprobs.top_logprobs.append({ - # Convert float("-inf") to the - # JSON-serializable float that OpenAI uses - p.decoded_token: max(p.logprob, -9999.0) - for i, p in step_top_logprobs.items() - } if step_top_logprobs else None) - - if len(logprobs.text_offset) == 0: - logprobs.text_offset.append(initial_text_offset) - else: - logprobs.text_offset.append(logprobs.text_offset[-1] + - last_token_len) - last_token_len = len(token) - return logprobs - def create_error_response( self, message: str, @@ -235,3 +190,8 @@ class OpenAIServing: f"Please reduce the length of the messages or completion.", ) else: return input_ids, input_text + + def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str: + if logprob.decoded_token is not None: + return logprob.decoded_token + return self.tokenizer.decode(token_id)
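
For reference, a minimal client-side sketch of the corrected chat logprobs shape: it assumes a locally running vLLM OpenAI-compatible server, and the base URL, API key, model name, and prompt below are illustrative placeholders, not values taken from this patch.

    from openai import OpenAI

    # Placeholder connection details; adjust to your deployment.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    chat = client.chat.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",  # illustrative model name
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=5,
        temperature=0.0,
        logprobs=True,
        top_logprobs=2,
    )

    choice = chat.choices[0]
    # Chat logprobs now follow the OpenAI schema: a list of per-token entries
    # under `logprobs.content`, each carrying its own `top_logprobs` list,
    # instead of the completions-style `top_logprobs` dicts used before.
    for entry in choice.logprobs.content:
        print(entry.token, entry.logprob, entry.bytes)
        for alt in entry.top_logprobs:
            print("  ", alt.token, alt.logprob)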