[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)

Co-authored-by: Breno Faria <breno.faria@intrafind.com>
Breno Faria 2024-05-30 11:52:14 +02:00 committed by GitHub
parent e07aff9e52
commit 87d41c849d
6 changed files with 362 additions and 99 deletions
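For context, the sketch below is not part of the diff; it shows the corrected response shape this fix produces on the chat endpoint, where logprobs now follow the OpenAI chat layout (choices[0].logprobs.content, one entry per token with its own top_logprobs) instead of the completions-style top_logprobs list. It assumes a vLLM OpenAI-compatible server is already running locally; the model name is a placeholder.

from openai import OpenAI

# Assumes a locally running vLLM OpenAI-compatible server; "my-model" is a
# placeholder for whatever model the server was started with.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "what is 1+1?"}],
    max_tokens=5,
    logprobs=True,
    top_logprobs=5,
)
# Corrected chat-style shape: one logprobs entry per generated token, each
# carrying its own list of top alternatives.
for entry in chat.choices[0].logprobs.content:
    print(entry.token, entry.logprob, [alt.token for alt in entry.top_logprobs])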

View File

@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
- assert chat_completion.choices[0].logprobs.top_logprobs is not None
- assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+ assert chat_completion.choices[0].logprobs.content[
+ 0].top_logprobs is not None
+ assert len(
+ chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"

View File

@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
completion.choices[0].text) >= 5
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=None,
)
choice = completion.choices[0]
assert choice.logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
# first test base model, then test loras
@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
- assert choice.logprobs.top_logprobs is None
+ assert choice.logprobs.top_logprobs is not None
+ assert len(choice.logprobs.top_logprobs[0]) <= 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
# test using token IDs
completion = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=5,
)
choice = completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.token_logprobs is not None
assert choice.logprobs.top_logprobs is not None
assert len(choice.logprobs.top_logprobs[0]) <= 6
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
)
...
with pytest.raises(
(openai.BadRequestError, openai.APIError)): # test using token IDs
stream = await client.completions.create(
model=MODEL_NAME,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
logprobs=6,
stream=True,
)
async for chunk in stream:
...
# the server should still work afterwards
completion = await client.completions.create(
model=model_name,
prompt=[0, 0, 0, 0, 0],
max_tokens=5,
temperature=0.0,
)
completion = completion.choices[0].text
assert completion is not None and len(completion) >= 0
@pytest.mark.asyncio
@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
- assert chat_completion.choices[0].logprobs.top_logprobs is not None
- assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+ assert chat_completion.choices[0].logprobs.content[
+ 0].top_logprobs is not None
+ assert len(
+ chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
@ -252,9 +339,13 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
@pytest.mark.parametrize(
# first test base model, then test loras
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
@ -263,13 +354,92 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
"content": "what is 1+1?"
}]
- # Default max_logprobs is 5, so this should raise an error
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=False)
choice = chat_completion.choices[0]
assert choice.logprobs is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=True,
top_logprobs=0)
choice = chat_completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.content is not None
assert len(choice.logprobs.content[0].top_logprobs) <= 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=5,
temperature=0.0,
logprobs=True,
top_logprobs=5)
choice = chat_completion.choices[0]
assert choice.logprobs is not None
assert choice.logprobs.content is not None
assert len(choice.logprobs.content[0].top_logprobs) <= 6
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
# Default max_logprobs is 20, so this should raise an error
with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
- top_logprobs=10,
+ top_logprobs=21,
stream=True)
async for chunk in stream:
...
@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
messages=messages,
max_tokens=10,
logprobs=True,
- top_logprobs=10,
+ top_logprobs=30,
stream=False)
- with pytest.raises((openai.BadRequestError, openai.APIError)):
- stream = await client.completions.create(model=model_name,
- prompt="Test",
- max_tokens=10,
- logprobs=10,
- stream=True)
- async for chunk in stream:
- ...
- with pytest.raises(openai.BadRequestError):
- await client.completions.create(model=model_name,
- prompt="Test",
- max_tokens=10,
- logprobs=10,
- stream=False)
# the server should still work afterwards
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
top_logprobs=5,
extra_body=dict(guided_choice=TEST_CHOICE,
guided_decoding_backend=guided_decoding_backend))
- top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+ top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
# -9999.0 is the minimum logprob returned by OpenAI
assert all(
- isinstance(logprob, float) and logprob >= -9999.0
- for token_dict in top_logprobs
- for token, logprob in token_dict.items())
+ isinstance(token.logprob, float) and token.logprob >= -9999.0
+ for token in top_logprobs)
@pytest.mark.asyncio

View File

@ -250,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "top_logprobs" in data and data["top_logprobs"] is not None:
if "logprobs" not in data or data["logprobs"] is False:
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
elif not 0 <= data["top_logprobs"] <= 20:
raise ValueError(
"`top_logprobs` must be a value in the interval [0, 20].")
return data
class CompletionRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
@ -396,6 +409,15 @@ class CompletionRequest(OpenAIBaseModel):
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if "logprobs" in data and data[
"logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
raise ValueError(("if passed, `logprobs` must be a value",
" in the interval [0, 5]."))
return data
class EmbeddingRequest(BaseModel):
# Ordered by official OpenAI API documentation
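Taken together, the two validators added above capture how the OpenAI parameters differ between endpoints: chat requests take a boolean logprobs plus an integer top_logprobs in [0, 20] (only when logprobs is true), while completion requests take a single integer logprobs in [0, 5]. Illustrative payloads, written as a sketch with a placeholder model name:

# Accepted by ChatCompletionRequest.check_logprobs:
chat_request = {
    "model": "my-model",  # placeholder
    "messages": [{"role": "user", "content": "what is 1+1?"}],
    "logprobs": True,     # boolean flag
    "top_logprobs": 5,    # integer in [0, 20]; requires logprobs=True
}

# Accepted by CompletionRequest.check_logprobs:
completion_request = {
    "model": "my-model",  # placeholder
    "prompt": "what is 1+1?",
    "logprobs": 5,        # integer in [0, 5]
}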
@ -415,7 +437,7 @@ class EmbeddingRequest(BaseModel):
return PoolingParams(additional_data=self.additional_data)
- class LogProbs(OpenAIBaseModel):
+ class CompletionLogProbs(OpenAIBaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
class CompletionResponseChoice(OpenAIBaseModel):
index: int
text: str
- logprobs: Optional[LogProbs] = None
+ logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
class CompletionResponseStreamChoice(OpenAIBaseModel):
index: int
text: str
- logprobs: Optional[LogProbs] = None
+ logprobs: Optional[CompletionLogProbs] = None
finish_reason: Optional[str] = None
stop_reason: Optional[Union[int, str]] = Field(
default=None,
@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel):
content: str
class ChatCompletionLogProb(OpenAIBaseModel):
token: str
logprob: float = -9999.0
bytes: Optional[List[int]] = None
class ChatCompletionLogProbsContent(ChatCompletionLogProb):
top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
class ChatCompletionLogProbs(OpenAIBaseModel):
content: Optional[List[ChatCompletionLogProbsContent]] = None
class ChatCompletionResponseChoice(OpenAIBaseModel):
index: int
message: ChatMessage
- logprobs: Optional[LogProbs] = None
- finish_reason: Optional[str] = None
+ logprobs: Optional[ChatCompletionLogProbs] = None
+ finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None
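For illustration (hand-written, not from the diff), this is how the new models above render inside a chat response choice; token text and logprob values are made up:

# Example serialization of ChatCompletionLogProbs for a single generated token.
example_choice_logprobs = {
    "content": [
        {
            "token": "2",
            "logprob": -0.01,
            "bytes": [50],            # UTF-8 bytes of the token text
            "top_logprobs": [         # ranked alternatives for this position
                {"token": "2", "logprob": -0.01, "bytes": [50]},
                {"token": " 2", "logprob": -4.7, "bytes": [32, 50]},
            ],
        },
    ],
}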
@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
index: int
delta: DeltaMessage
- logprobs: Optional[LogProbs] = None
- finish_reason: Optional[str] = None
+ logprobs: Optional[ChatCompletionLogProbs] = None
+ finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
stop_reason: Optional[Union[int, str]] = None

View File

@ -1,8 +1,10 @@
import codecs
import time
from dataclasses import dataclass
- from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional,
- TypedDict, Union, cast, final)
+ from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
+ Optional)
+ from typing import Sequence as GenericSequence
+ from typing import TypedDict, Union, cast, final
from fastapi import Request
from openai.types.chat import ChatCompletionContentPartTextParam
@ -10,8 +12,9 @@ from openai.types.chat import ChatCompletionContentPartTextParam
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
- ChatCompletionContentPartParam, ChatCompletionMessageParam,
- ChatCompletionRequest, ChatCompletionResponse,
+ ChatCompletionContentPartParam, ChatCompletionLogProb,
+ ChatCompletionLogProbs, ChatCompletionLogProbsContent,
+ ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
@ -21,6 +24,7 @@ from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.utils import random_uuid
logger = init_logger(__name__)
@ -283,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs:
- logprobs = self._create_logprobs(
+ logprobs = self._create_chat_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.top_logprobs,
- initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
@ -370,7 +373,7 @@ class OpenAIServingChat(OpenAIServing):
top_logprobs = output.logprobs
if request.logprobs:
- logprobs = self._create_logprobs(
+ logprobs = self._create_chat_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.top_logprobs,
@ -383,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
- stop_reason=output.stop_reason,
- )
+ stop_reason=output.stop_reason)
choices.append(choice_data)
if request.echo:
@ -414,3 +416,51 @@ class OpenAIServingChat(OpenAIServing):
)
return response
def _get_top_logprobs(
self, logprobs: Dict[int, Logprob],
top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
return [
ChatCompletionLogProb(
token=self._get_decoded_token(p[1], p[0]),
logprob=max(p[1].logprob, -9999.0),
bytes=list(
self._get_decoded_token(p[1],
p[0]).encode("utf-8",
errors="replace")))
for i, p in enumerate(logprobs.items())
if top_logprobs and i < top_logprobs
]
def _create_chat_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: Optional[int] = None,
) -> ChatCompletionLogProbs:
"""Create OpenAI-style logprobs."""
logprobs_content = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=self.tokenizer.decode(token_id),
bytes=list(
self.tokenizer.decode(token_id).encode(
"utf-8", errors="replace"))))
else:
logprobs_content.append(
ChatCompletionLogProbsContent(
token=step_top_logprobs[token_id].decoded_token,
logprob=max(step_top_logprobs[token_id].logprob,
-9999.0),
bytes=list(
step_top_logprobs[token_id].decoded_token.encode(
"utf-8", errors="replace")),
top_logprobs=self._get_top_logprobs(
step_top_logprobs, num_output_top_logprobs)))
return ChatCompletionLogProbs(content=logprobs_content)
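A minimal, self-contained sketch (stand-in types, not vLLM code) of the mapping _create_chat_logprobs performs for one decoding step: the sampled token's own logprob is reported alongside up to num_output_top_logprobs ranked alternatives, with -9999.0 used as the floor for -inf:

from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeLogprob:  # stand-in for vllm.sequence.Logprob
    logprob: float
    decoded_token: str


def top_entries(step: Dict[int, FakeLogprob], limit: int) -> List[dict]:
    # Assumes the per-step dict is ordered best-first, which is what
    # _get_top_logprobs above relies on when it truncates by index.
    return [{"token": lp.decoded_token, "logprob": max(lp.logprob, -9999.0)}
            for i, lp in enumerate(step.values()) if i < limit]


# One decoding step: sampled token id 7 ("2"), runner-up id 9 (" 2").
step = {7: FakeLogprob(-0.02, "2"), 9: FakeLogprob(-4.1, " 2")}
entry = {
    "token": step[7].decoded_token,
    "logprob": max(step[7].logprob, -9999.0),
    "top_logprobs": top_entries(step, limit=2),
}
print(entry)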

View File

@ -1,23 +1,29 @@
import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
- Optional, Tuple)
+ Optional)
+ from typing import Sequence as GenericSequence
+ from typing import Tuple
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
- from vllm.entrypoints.openai.protocol import (CompletionRequest,
+ # yapf: disable
+ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
+ CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
- LogProbs, UsageInfo)
+ UsageInfo)
+ # yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.sequence import Logprob
from vllm.utils import merge_async_iterators, random_uuid
logger = init_logger(__name__)
@ -25,7 +31,7 @@ logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
- [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
+ [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]:
@ -235,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
i]:] if output.logprobs else None
if request.logprobs is not None:
- logprobs = self._create_logprobs(
+ logprobs = self._create_completion_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
@ -317,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
assert top_logprobs is not None, (
"top_logprobs must be provided when logprobs "
"is requested")
- logprobs = self._create_logprobs(
+ logprobs = self._create_completion_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
@ -351,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
choices=choices,
usage=usage,
)
def _create_completion_logprobs(
self,
token_ids: GenericSequence[int],
top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
num_output_top_logprobs: int,
initial_text_offset: int = 0,
) -> CompletionLogProbs:
"""Create logprobs for OpenAI Completion API."""
out_text_offset: List[int] = []
out_token_logprobs: List[Optional[float]] = []
out_tokens: List[str] = []
out_top_logprobs: List[Optional[Dict[str, float]]] = []
last_token_len = 0
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is None:
token = self.tokenizer.decode(token_id)
out_tokens.append(token)
out_token_logprobs.append(None)
out_top_logprobs.append(None)
else:
token = self._get_decoded_token(step_top_logprobs[token_id],
token_id)
token_logprob = max(step_top_logprobs[token_id].logprob,
-9999.0)
out_tokens.append(token)
out_token_logprobs.append(token_logprob)
# makes sure to add the top num_output_top_logprobs + 1
# logprobs, as defined in the openai API
# (cf. https://github.com/openai/openai-openapi/blob/
# 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
out_top_logprobs.append({
# Convert float("-inf") to the
# JSON-serializable float that OpenAI uses
self._get_decoded_token(top_lp[1], top_lp[0]):
max(top_lp[1].logprob, -9999.0)
for i, top_lp in enumerate(step_top_logprobs.items())
if num_output_top_logprobs >= i
})
if len(out_text_offset) == 0:
out_text_offset.append(initial_text_offset)
else:
out_text_offset.append(out_text_offset[-1] + last_token_len)
last_token_len = len(token)
return CompletionLogProbs(
text_offset=out_text_offset,
token_logprobs=out_token_logprobs,
tokens=out_tokens,
top_logprobs=out_top_logprobs,
)
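A small worked example (made-up tokens) of the text_offset bookkeeping above: each token's offset is the previous offset plus the previous token's length, starting from initial_text_offset:

tokens = ["Hello", ",", " world"]   # hypothetical decoded tokens
offsets, last_len = [], 0
for tok in tokens:
    # First token starts at initial_text_offset (0 here); later tokens start
    # where the previous one ended.
    offsets.append(0 if not offsets else offsets[-1] + last_len)
    last_len = len(tok)
print(offsets)  # [0, 5, 6]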

View File

@ -11,7 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
EmbeddingRequest, ErrorResponse,
- LogProbs, ModelCard, ModelList,
+ ModelCard, ModelList,
ModelPermission)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -75,51 +75,6 @@ class OpenAIServing:
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
- def _create_logprobs(
- self,
- token_ids: List[int],
- top_logprobs: List[Optional[Dict[int, Logprob]]],
- num_output_top_logprobs: Optional[int] = None,
- initial_text_offset: int = 0,
- ) -> LogProbs:
- """Create OpenAI-style logprobs."""
- logprobs = LogProbs()
- last_token_len = 0
- if num_output_top_logprobs:
- logprobs.top_logprobs = []
- for i, token_id in enumerate(token_ids):
- step_top_logprobs = top_logprobs[i]
- if step_top_logprobs is None:
- token = self.tokenizer.decode(token_id)
- logprobs.tokens.append(token)
- logprobs.token_logprobs.append(None)
- assert logprobs.top_logprobs is not None
- logprobs.top_logprobs.append(None)
- else:
- token_logprob = step_top_logprobs[token_id].logprob
- token = step_top_logprobs[token_id].decoded_token
- logprobs.tokens.append(token)
- token_logprob = max(token_logprob, -9999.0)
- logprobs.token_logprobs.append(token_logprob)
- if num_output_top_logprobs:
- assert logprobs.top_logprobs is not None
- logprobs.top_logprobs.append({
- # Convert float("-inf") to the
- # JSON-serializable float that OpenAI uses
- p.decoded_token: max(p.logprob, -9999.0)
- for i, p in step_top_logprobs.items()
- } if step_top_logprobs else None)
- if len(logprobs.text_offset) == 0:
- logprobs.text_offset.append(initial_text_offset)
- else:
- logprobs.text_offset.append(logprobs.text_offset[-1] +
- last_token_len)
- last_token_len = len(token)
- return logprobs
def create_error_response(
self,
message: str,
@ -235,3 +190,8 @@ class OpenAIServing:
f"Please reduce the length of the messages or completion.", )
else:
return input_ids, input_text
def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return self.tokenizer.decode(token_id)