[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
This commit is contained in:
parent
e07aff9e52
commit
87d41c849d
@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
|
||||
chat_completion.choices) == 1
|
||||
assert chat_completion.choices[0].message is not None
|
||||
assert chat_completion.choices[0].logprobs is not None
|
||||
assert chat_completion.choices[0].logprobs.top_logprobs is not None
|
||||
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
|
||||
assert chat_completion.choices[0].logprobs.content[
|
||||
0].top_logprobs is not None
|
||||
assert len(
|
||||
chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 10
|
||||
assert message.role == "assistant"
|
||||
|
@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
|
||||
completion.choices[0].text) >= 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
async def test_no_logprobs(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=None,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert len(choice.logprobs.top_logprobs[0]) <= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_some_logprobs(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
# test using token IDs
|
||||
completion = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=5,
|
||||
)
|
||||
choice = completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert len(choice.logprobs.top_logprobs[0]) <= 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=6,
|
||||
)
|
||||
...
|
||||
with pytest.raises(
|
||||
(openai.BadRequestError, openai.APIError)): # test using token IDs
|
||||
stream = await client.completions.create(
|
||||
model=MODEL_NAME,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=6,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in stream:
|
||||
...
|
||||
|
||||
# the server should still work afterwards
|
||||
completion = await client.completions.create(
|
||||
model=model_name,
|
||||
prompt=[0, 0, 0, 0, 0],
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
)
|
||||
completion = completion.choices[0].text
|
||||
assert completion is not None and len(completion) >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
|
||||
chat_completion.choices) == 1
|
||||
assert chat_completion.choices[0].message is not None
|
||||
assert chat_completion.choices[0].logprobs is not None
|
||||
assert chat_completion.choices[0].logprobs.top_logprobs is not None
|
||||
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
|
||||
assert chat_completion.choices[0].logprobs.content[
|
||||
0].top_logprobs is not None
|
||||
assert len(
|
||||
chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
|
||||
message = chat_completion.choices[0].message
|
||||
assert message.content is not None and len(message.content) >= 10
|
||||
assert message.role == "assistant"
|
||||
@ -252,8 +339,12 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
|
||||
@pytest.mark.parametrize(
|
||||
# first test base model, then test loras
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
@ -263,13 +354,92 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
# Default max_logprobs is 5, so this should raise an error
|
||||
chat_completion = await client.chat.completions.create(model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=False)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.logprobs is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
# just test 1 lora hereafter
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
chat_completion = await client.chat.completions.create(model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=True,
|
||||
top_logprobs=0)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.content is not None
|
||||
assert len(choice.logprobs.content[0].top_logprobs) <= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora"],
|
||||
)
|
||||
async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
chat_completion = await client.chat.completions.create(model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
logprobs=True,
|
||||
top_logprobs=5)
|
||||
|
||||
choice = chat_completion.choices[0]
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.content is not None
|
||||
assert len(choice.logprobs.content[0].top_logprobs) <= 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
# Default max_logprobs is 20, so this should raise an error
|
||||
with pytest.raises((openai.BadRequestError, openai.APIError)):
|
||||
stream = await client.chat.completions.create(model=model_name,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
top_logprobs=21,
|
||||
stream=True)
|
||||
async for chunk in stream:
|
||||
...
|
||||
@ -279,23 +449,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
|
||||
messages=messages,
|
||||
max_tokens=10,
|
||||
logprobs=True,
|
||||
top_logprobs=10,
|
||||
stream=False)
|
||||
|
||||
with pytest.raises((openai.BadRequestError, openai.APIError)):
|
||||
stream = await client.completions.create(model=model_name,
|
||||
prompt="Test",
|
||||
max_tokens=10,
|
||||
logprobs=10,
|
||||
stream=True)
|
||||
async for chunk in stream:
|
||||
...
|
||||
|
||||
with pytest.raises(openai.BadRequestError):
|
||||
await client.completions.create(model=model_name,
|
||||
prompt="Test",
|
||||
max_tokens=10,
|
||||
logprobs=10,
|
||||
top_logprobs=30,
|
||||
stream=False)
|
||||
|
||||
# the server should still work afterwards
|
||||
@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
|
||||
top_logprobs=5,
|
||||
extra_body=dict(guided_choice=TEST_CHOICE,
|
||||
guided_decoding_backend=guided_decoding_backend))
|
||||
top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
|
||||
top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
|
||||
|
||||
# -9999.0 is the minimum logprob returned by OpenAI
|
||||
assert all(
|
||||
isinstance(logprob, float) and logprob >= -9999.0
|
||||
for token_dict in top_logprobs
|
||||
for token, logprob in token_dict.items())
|
||||
isinstance(token.logprob, float) and token.logprob >= -9999.0
|
||||
for token in top_logprobs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -250,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
|
||||
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if "top_logprobs" in data and data["top_logprobs"] is not None:
|
||||
if "logprobs" not in data or data["logprobs"] is False:
|
||||
raise ValueError(
|
||||
"when using `top_logprobs`, `logprobs` must be set to true."
|
||||
)
|
||||
elif not 0 <= data["top_logprobs"] <= 20:
|
||||
raise ValueError(
|
||||
"`top_logprobs` must be a value in the interval [0, 20].")
|
||||
return data
|
||||
|
||||
|
||||
class CompletionRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
@ -396,6 +409,15 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
"('guided_json', 'guided_regex' or 'guided_choice').")
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_logprobs(cls, data):
|
||||
if "logprobs" in data and data[
|
||||
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
|
||||
raise ValueError(("if passed, `logprobs` must be a value",
|
||||
" in the interval [0, 5]."))
|
||||
return data
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
@ -415,7 +437,7 @@ class EmbeddingRequest(BaseModel):
|
||||
return PoolingParams(additional_data=self.additional_data)
|
||||
|
||||
|
||||
class LogProbs(OpenAIBaseModel):
|
||||
class CompletionLogProbs(OpenAIBaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
||||
tokens: List[str] = Field(default_factory=list)
|
||||
@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
|
||||
class CompletionResponseChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: Optional[LogProbs] = None
|
||||
logprobs: Optional[CompletionLogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Optional[Union[int, str]] = Field(
|
||||
default=None,
|
||||
@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
|
||||
class CompletionResponseStreamChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
text: str
|
||||
logprobs: Optional[LogProbs] = None
|
||||
logprobs: Optional[CompletionLogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
stop_reason: Optional[Union[int, str]] = Field(
|
||||
default=None,
|
||||
@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class ChatCompletionLogProb(OpenAIBaseModel):
|
||||
token: str
|
||||
logprob: float = -9999.0
|
||||
bytes: Optional[List[int]] = None
|
||||
|
||||
|
||||
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
|
||||
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogProbs(OpenAIBaseModel):
|
||||
content: Optional[List[ChatCompletionLogProbsContent]] = None
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
message: ChatMessage
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
logprobs: Optional[ChatCompletionLogProbs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
|
||||
stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
|
||||
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
|
||||
index: int
|
||||
delta: DeltaMessage
|
||||
logprobs: Optional[LogProbs] = None
|
||||
finish_reason: Optional[str] = None
|
||||
logprobs: Optional[ChatCompletionLogProbs] = None
|
||||
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
|
||||
stop_reason: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
import codecs
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional,
|
||||
TypedDict, Union, cast, final)
|
||||
from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import TypedDict, Union, cast, final
|
||||
|
||||
from fastapi import Request
|
||||
from openai.types.chat import ChatCompletionContentPartTextParam
|
||||
@ -10,8 +12,9 @@ from openai.types.chat import ChatCompletionContentPartTextParam
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (
|
||||
ChatCompletionContentPartParam, ChatCompletionMessageParam,
|
||||
ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionContentPartParam, ChatCompletionLogProb,
|
||||
ChatCompletionLogProbs, ChatCompletionLogProbsContent,
|
||||
ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
||||
UsageInfo)
|
||||
@ -21,6 +24,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -283,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
previous_num_tokens[i]:] if output.logprobs else None
|
||||
|
||||
if request.logprobs:
|
||||
logprobs = self._create_logprobs(
|
||||
logprobs = self._create_chat_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.top_logprobs,
|
||||
initial_text_offset=len(previous_texts[i]),
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
@ -370,7 +373,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
top_logprobs = output.logprobs
|
||||
|
||||
if request.logprobs:
|
||||
logprobs = self._create_logprobs(
|
||||
logprobs = self._create_chat_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.top_logprobs,
|
||||
@ -383,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
stop_reason=output.stop_reason,
|
||||
)
|
||||
stop_reason=output.stop_reason)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo:
|
||||
@ -414,3 +416,51 @@ class OpenAIServingChat(OpenAIServing):
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _get_top_logprobs(
|
||||
self, logprobs: Dict[int, Logprob],
|
||||
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
|
||||
return [
|
||||
ChatCompletionLogProb(
|
||||
token=self._get_decoded_token(p[1], p[0]),
|
||||
logprob=max(p[1].logprob, -9999.0),
|
||||
bytes=list(
|
||||
self._get_decoded_token(p[1],
|
||||
p[0]).encode("utf-8",
|
||||
errors="replace")))
|
||||
for i, p in enumerate(logprobs.items())
|
||||
if top_logprobs and i < top_logprobs
|
||||
]
|
||||
|
||||
def _create_chat_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||
num_output_top_logprobs: Optional[int] = None,
|
||||
) -> ChatCompletionLogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
|
||||
logprobs_content = []
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(
|
||||
token=self.tokenizer.decode(token_id),
|
||||
bytes=list(
|
||||
self.tokenizer.decode(token_id).encode(
|
||||
"utf-8", errors="replace"))))
|
||||
else:
|
||||
logprobs_content.append(
|
||||
ChatCompletionLogProbsContent(
|
||||
token=step_top_logprobs[token_id].decoded_token,
|
||||
logprob=max(step_top_logprobs[token_id].logprob,
|
||||
-9999.0),
|
||||
bytes=list(
|
||||
step_top_logprobs[token_id].decoded_token.encode(
|
||||
"utf-8", errors="replace")),
|
||||
top_logprobs=self._get_top_logprobs(
|
||||
step_top_logprobs, num_output_top_logprobs)))
|
||||
|
||||
return ChatCompletionLogProbs(content=logprobs_content)
|
||||
|
@ -1,23 +1,29 @@
|
||||
import time
|
||||
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
|
||||
Optional, Tuple)
|
||||
Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (CompletionRequest,
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
CompletionResponseChoice,
|
||||
CompletionResponseStreamChoice,
|
||||
CompletionStreamResponse,
|
||||
LogProbs, UsageInfo)
|
||||
UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
OpenAIServing)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.guided_decoding import (
|
||||
get_guided_decoding_logits_processor)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.utils import merge_async_iterators, random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -25,7 +31,7 @@ logger = init_logger(__name__)
|
||||
TypeTokenIDs = List[int]
|
||||
TypeTopLogProbs = List[Optional[Dict[int, float]]]
|
||||
TypeCreateLogProbsFn = Callable[
|
||||
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
|
||||
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
|
||||
|
||||
|
||||
def parse_prompt_format(prompt) -> Tuple[bool, list]:
|
||||
@ -235,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
i]:] if output.logprobs else None
|
||||
|
||||
if request.logprobs is not None:
|
||||
logprobs = self._create_logprobs(
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
@ -317,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
assert top_logprobs is not None, (
|
||||
"top_logprobs must be provided when logprobs "
|
||||
"is requested")
|
||||
logprobs = self._create_logprobs(
|
||||
logprobs = self._create_completion_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
@ -351,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
def _create_completion_logprobs(
|
||||
self,
|
||||
token_ids: GenericSequence[int],
|
||||
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
|
||||
num_output_top_logprobs: int,
|
||||
initial_text_offset: int = 0,
|
||||
) -> CompletionLogProbs:
|
||||
"""Create logprobs for OpenAI Completion API."""
|
||||
out_text_offset: List[int] = []
|
||||
out_token_logprobs: List[Optional[float]] = []
|
||||
out_tokens: List[str] = []
|
||||
out_top_logprobs: List[Optional[Dict[str, float]]] = []
|
||||
|
||||
last_token_len = 0
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = self.tokenizer.decode(token_id)
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(None)
|
||||
out_top_logprobs.append(None)
|
||||
else:
|
||||
token = self._get_decoded_token(step_top_logprobs[token_id],
|
||||
token_id)
|
||||
token_logprob = max(step_top_logprobs[token_id].logprob,
|
||||
-9999.0)
|
||||
out_tokens.append(token)
|
||||
out_token_logprobs.append(token_logprob)
|
||||
|
||||
# makes sure to add the top num_output_top_logprobs + 1
|
||||
# logprobs, as defined in the openai API
|
||||
# (cf. https://github.com/openai/openai-openapi/blob/
|
||||
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
|
||||
out_top_logprobs.append({
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
self._get_decoded_token(top_lp[1], top_lp[0]):
|
||||
max(top_lp[1].logprob, -9999.0)
|
||||
for i, top_lp in enumerate(step_top_logprobs.items())
|
||||
if num_output_top_logprobs >= i
|
||||
})
|
||||
|
||||
if len(out_text_offset) == 0:
|
||||
out_text_offset.append(initial_text_offset)
|
||||
else:
|
||||
out_text_offset.append(out_text_offset[-1] + last_token_len)
|
||||
last_token_len = len(token)
|
||||
|
||||
return CompletionLogProbs(
|
||||
text_offset=out_text_offset,
|
||||
token_logprobs=out_token_logprobs,
|
||||
tokens=out_tokens,
|
||||
top_logprobs=out_top_logprobs,
|
||||
)
|
||||
|
@ -11,7 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
EmbeddingRequest, ErrorResponse,
|
||||
LogProbs, ModelCard, ModelList,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@ -75,51 +75,6 @@ class OpenAIServing:
|
||||
model_cards.extend(lora_cards)
|
||||
return ModelList(data=model_cards)
|
||||
|
||||
def _create_logprobs(
|
||||
self,
|
||||
token_ids: List[int],
|
||||
top_logprobs: List[Optional[Dict[int, Logprob]]],
|
||||
num_output_top_logprobs: Optional[int] = None,
|
||||
initial_text_offset: int = 0,
|
||||
) -> LogProbs:
|
||||
"""Create OpenAI-style logprobs."""
|
||||
logprobs = LogProbs()
|
||||
last_token_len = 0
|
||||
if num_output_top_logprobs:
|
||||
logprobs.top_logprobs = []
|
||||
|
||||
for i, token_id in enumerate(token_ids):
|
||||
step_top_logprobs = top_logprobs[i]
|
||||
if step_top_logprobs is None:
|
||||
token = self.tokenizer.decode(token_id)
|
||||
logprobs.tokens.append(token)
|
||||
logprobs.token_logprobs.append(None)
|
||||
assert logprobs.top_logprobs is not None
|
||||
logprobs.top_logprobs.append(None)
|
||||
else:
|
||||
token_logprob = step_top_logprobs[token_id].logprob
|
||||
token = step_top_logprobs[token_id].decoded_token
|
||||
logprobs.tokens.append(token)
|
||||
token_logprob = max(token_logprob, -9999.0)
|
||||
logprobs.token_logprobs.append(token_logprob)
|
||||
|
||||
if num_output_top_logprobs:
|
||||
assert logprobs.top_logprobs is not None
|
||||
logprobs.top_logprobs.append({
|
||||
# Convert float("-inf") to the
|
||||
# JSON-serializable float that OpenAI uses
|
||||
p.decoded_token: max(p.logprob, -9999.0)
|
||||
for i, p in step_top_logprobs.items()
|
||||
} if step_top_logprobs else None)
|
||||
|
||||
if len(logprobs.text_offset) == 0:
|
||||
logprobs.text_offset.append(initial_text_offset)
|
||||
else:
|
||||
logprobs.text_offset.append(logprobs.text_offset[-1] +
|
||||
last_token_len)
|
||||
last_token_len = len(token)
|
||||
return logprobs
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str,
|
||||
@ -235,3 +190,8 @@ class OpenAIServing:
|
||||
f"Please reduce the length of the messages or completion.", )
|
||||
else:
|
||||
return input_ids, input_text
|
||||
|
||||
def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
|
||||
if logprob.decoded_token is not None:
|
||||
return logprob.decoded_token
|
||||
return self.tokenizer.decode(token_id)
|
||||
|
Loading…
x
Reference in New Issue
Block a user