[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
Parent: e07aff9e52
Commit: 87d41c849d
@@ -94,8 +94,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI):
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+    assert chat_completion.choices[0].logprobs.content[
+        0].top_logprobs is not None
+    assert len(
+        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
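Note: the updated assertions track the OpenAI chat logprobs schema, where per-token entries live under choices[0].logprobs.content and each entry carries its own top_logprobs list, instead of the flat top_logprobs field the old assertions expected. Below is a minimal client-side sketch of reading that structure from a running vLLM OpenAI-compatible server; the base URL, API key, and model name are placeholders for this sketch, not values taken from the diff.

import asyncio

import openai


async def main():
    # Point the official OpenAI client at a locally running vLLM server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    chat_completion = await client.chat.completions.create(
        model="HuggingFaceH4/zephyr-7b-beta",
        messages=[{"role": "user", "content": "what is 1+1?"}],
        max_tokens=5,
        logprobs=True,
        top_logprobs=5,
    )
    choice = chat_completion.choices[0]
    # One entry per sampled token; each entry has its own list of alternatives.
    for entry in choice.logprobs.content:
        alternatives = [(alt.token, alt.logprob) for alt in entry.top_logprobs]
        print(entry.token, entry.logprob, alternatives)


asyncio.run(main())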
@@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         completion.choices[0].text) >= 5
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs(server, client: openai.AsyncOpenAI,
+                           model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=None,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
@@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     choice = completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
-    assert choice.logprobs.top_logprobs is None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs(server, client: openai.AsyncOpenAI,
+                             model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=5,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.token_logprobs is not None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 6
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
+                                            model_name: str):
+
+    with pytest.raises(
+        (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+        )
+        ...
+
+    with pytest.raises(
+        (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        stream = await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+            stream=True,
+        )
+        async for chunk in stream:
+            ...
+
+    # the server should still work afterwards
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion = completion.choices[0].text
+    assert completion is not None and len(completion) >= 0
 
 
 @pytest.mark.asyncio
@@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+    assert chat_completion.choices[0].logprobs.content[
+        0].top_logprobs is not None
+    assert len(
+        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
@@ -252,9 +339,13 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -263,13 +354,92 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
         "content": "what is 1+1?"
     }]
 
-    # Default max_logprobs is 5, so this should raise an error
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=False)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=0)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 6
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
+                                      model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # Default max_logprobs is 20, so this should raise an error
     with pytest.raises((openai.BadRequestError, openai.APIError)):
         stream = await client.chat.completions.create(model=model_name,
                                                       messages=messages,
                                                       max_tokens=10,
                                                       logprobs=True,
-                                                      top_logprobs=10,
+                                                      top_logprobs=21,
                                                       stream=True)
         async for chunk in stream:
             ...
@@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
                                              messages=messages,
                                              max_tokens=10,
                                              logprobs=True,
-                                             top_logprobs=10,
+                                             top_logprobs=30,
                                              stream=False)
 
-    with pytest.raises((openai.BadRequestError, openai.APIError)):
-        stream = await client.completions.create(model=model_name,
-                                                 prompt="Test",
-                                                 max_tokens=10,
-                                                 logprobs=10,
-                                                 stream=True)
-        async for chunk in stream:
-            ...
-
-    with pytest.raises(openai.BadRequestError):
-        await client.completions.create(model=model_name,
-                                        prompt="Test",
-                                        max_tokens=10,
-                                        logprobs=10,
-                                        stream=False)
-
     # the server should still work afterwards
     chat_completion = await client.chat.completions.create(model=model_name,
                                                             messages=messages,
@@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
         top_logprobs=5,
         extra_body=dict(guided_choice=TEST_CHOICE,
                         guided_decoding_backend=guided_decoding_backend))
-    top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs
 
     # -9999.0 is the minimum logprob returned by OpenAI
     assert all(
-        isinstance(logprob, float) and logprob >= -9999.0
-        for token_dict in top_logprobs
-        for token, logprob in token_dict.items())
+        isinstance(token.logprob, float) and token.logprob >= -9999.0
+        for token in top_logprobs)
 
 
 @pytest.mark.asyncio
@@ -250,6 +250,19 @@ class ChatCompletionRequest(OpenAIBaseModel):
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "top_logprobs" in data and data["top_logprobs"] is not None:
+            if "logprobs" not in data or data["logprobs"] is False:
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+            elif not 0 <= data["top_logprobs"] <= 20:
+                raise ValueError(
+                    "`top_logprobs` must be a value in the interval [0, 20].")
+        return data
+
 
 class CompletionRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
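Note: the new check_logprobs validator mirrors the OpenAI chat API rules: top_logprobs requires logprobs to be set to true and must lie in [0, 20]. The standalone pydantic sketch below reproduces the same check on a stripped-down stand-in model; it is illustrative only and is not the real ChatCompletionRequest.

from typing import Optional

from pydantic import BaseModel, model_validator


class ChatRequestSketch(BaseModel):
    # Only the two fields the validator looks at.
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if "top_logprobs" in data and data["top_logprobs"] is not None:
            if "logprobs" not in data or data["logprobs"] is False:
                raise ValueError(
                    "when using `top_logprobs`, `logprobs` must be set to true."
                )
            elif not 0 <= data["top_logprobs"] <= 20:
                raise ValueError(
                    "`top_logprobs` must be a value in the interval [0, 20].")
        return data


ChatRequestSketch(logprobs=True, top_logprobs=5)  # accepted
# ChatRequestSketch(top_logprobs=5)                  -> ValidationError
# ChatRequestSketch(logprobs=True, top_logprobs=21)  -> ValidationError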
@@ -396,6 +409,15 @@ class CompletionRequest(OpenAIBaseModel):
                 "('guided_json', 'guided_regex' or 'guided_choice').")
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if "logprobs" in data and data[
+                "logprobs"] is not None and not 0 <= data["logprobs"] <= 5:
+            raise ValueError(("if passed, `logprobs` must be a value",
+                              " in the interval [0, 5]."))
+        return data
+
 
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
@@ -415,7 +437,7 @@ class EmbeddingRequest(BaseModel):
         return PoolingParams(additional_data=self.additional_data)
 
 
-class LogProbs(OpenAIBaseModel):
+class CompletionLogProbs(OpenAIBaseModel):
     text_offset: List[int] = Field(default_factory=list)
     token_logprobs: List[Optional[float]] = Field(default_factory=list)
     tokens: List[str] = Field(default_factory=list)
@@ -425,7 +447,7 @@ class LogProbs(OpenAIBaseModel):
 class CompletionResponseChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = Field(
         default=None,
@@ -448,7 +470,7 @@ class CompletionResponse(OpenAIBaseModel):
 class CompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     text: str
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[CompletionLogProbs] = None
     finish_reason: Optional[str] = None
     stop_reason: Optional[Union[int, str]] = Field(
         default=None,
@@ -488,11 +510,25 @@ class ChatMessage(OpenAIBaseModel):
     content: str
 
 
+class ChatCompletionLogProb(OpenAIBaseModel):
+    token: str
+    logprob: float = -9999.0
+    bytes: Optional[List[int]] = None
+
+
+class ChatCompletionLogProbsContent(ChatCompletionLogProb):
+    top_logprobs: List[ChatCompletionLogProb] = Field(default_factory=list)
+
+
+class ChatCompletionLogProbs(OpenAIBaseModel):
+    content: Optional[List[ChatCompletionLogProbsContent]] = None
+
+
 class ChatCompletionResponseChoice(OpenAIBaseModel):
     index: int
     message: ChatMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     stop_reason: Optional[Union[int, str]] = None
 
 
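Note (illustrative only): a populated ChatCompletionLogProbs field serializes to the same nested shape the OpenAI chat API returns. The token strings and numbers below are invented for the sketch.

example_chat_logprobs = {
    "content": [{
        "token": "2",
        "logprob": -0.1,
        "bytes": [50],
        "top_logprobs": [
            {"token": "2", "logprob": -0.1, "bytes": [50]},
            {"token": " two", "logprob": -3.2, "bytes": [32, 116, 119, 111]},
        ],
    }],
}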
@@ -513,8 +549,8 @@ class DeltaMessage(OpenAIBaseModel):
 class ChatCompletionResponseStreamChoice(OpenAIBaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
-    finish_reason: Optional[str] = None
+    logprobs: Optional[ChatCompletionLogProbs] = None
+    finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     stop_reason: Optional[Union[int, str]] = None
 
 
@@ -1,8 +1,10 @@
 import codecs
 import time
 from dataclasses import dataclass
-from typing import (AsyncGenerator, AsyncIterator, Iterable, List, Optional,
-                    TypedDict, Union, cast, final)
+from typing import (AsyncGenerator, AsyncIterator, Dict, Iterable, List,
+                    Optional)
+from typing import Sequence as GenericSequence
+from typing import TypedDict, Union, cast, final
 
 from fastapi import Request
 from openai.types.chat import ChatCompletionContentPartTextParam
@@ -10,8 +12,9 @@ from openai.types.chat import ChatCompletionContentPartTextParam
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
-    ChatCompletionContentPartParam, ChatCompletionMessageParam,
-    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionContentPartParam, ChatCompletionLogProb,
+    ChatCompletionLogProbs, ChatCompletionLogProbsContent,
+    ChatCompletionMessageParam, ChatCompletionRequest, ChatCompletionResponse,
     ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     UsageInfo)
@@ -21,6 +24,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
+from vllm.sequence import Logprob
 from vllm.utils import random_uuid
 
 logger = init_logger(__name__)
@@ -283,11 +287,10 @@ class OpenAIServingChat(OpenAIServing):
                     previous_num_tokens[i]:] if output.logprobs else None
 
                 if request.logprobs:
-                    logprobs = self._create_logprobs(
+                    logprobs = self._create_chat_logprobs(
                         token_ids=delta_token_ids,
                         top_logprobs=top_logprobs,
                         num_output_top_logprobs=request.top_logprobs,
-                        initial_text_offset=len(previous_texts[i]),
                     )
                 else:
                     logprobs = None
@@ -370,7 +373,7 @@ class OpenAIServingChat(OpenAIServing):
             top_logprobs = output.logprobs
 
             if request.logprobs:
-                logprobs = self._create_logprobs(
+                logprobs = self._create_chat_logprobs(
                     token_ids=token_ids,
                     top_logprobs=top_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
@@ -383,8 +386,7 @@ class OpenAIServingChat(OpenAIServing):
                 message=ChatMessage(role=role, content=output.text),
                 logprobs=logprobs,
                 finish_reason=output.finish_reason,
-                stop_reason=output.stop_reason,
-            )
+                stop_reason=output.stop_reason)
             choices.append(choice_data)
 
         if request.echo:
@@ -414,3 +416,51 @@ class OpenAIServingChat(OpenAIServing):
         )
 
         return response
+
+    def _get_top_logprobs(
+            self, logprobs: Dict[int, Logprob],
+            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+        return [
+            ChatCompletionLogProb(
+                token=self._get_decoded_token(p[1], p[0]),
+                logprob=max(p[1].logprob, -9999.0),
+                bytes=list(
+                    self._get_decoded_token(p[1],
+                                            p[0]).encode("utf-8",
+                                                         errors="replace")))
+            for i, p in enumerate(logprobs.items())
+            if top_logprobs and i < top_logprobs
+        ]
+
+    def _create_chat_logprobs(
+        self,
+        token_ids: GenericSequence[int],
+        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        num_output_top_logprobs: Optional[int] = None,
+    ) -> ChatCompletionLogProbs:
+        """Create OpenAI-style logprobs."""
+
+        logprobs_content = []
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is None:
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=self.tokenizer.decode(token_id),
+                        bytes=list(
+                            self.tokenizer.decode(token_id).encode(
+                                "utf-8", errors="replace"))))
+            else:
+                logprobs_content.append(
+                    ChatCompletionLogProbsContent(
+                        token=step_top_logprobs[token_id].decoded_token,
+                        logprob=max(step_top_logprobs[token_id].logprob,
+                                    -9999.0),
+                        bytes=list(
+                            step_top_logprobs[token_id].decoded_token.encode(
+                                "utf-8", errors="replace")),
+                        top_logprobs=self._get_top_logprobs(
+                            step_top_logprobs, num_output_top_logprobs)))
+
+        return ChatCompletionLogProbs(content=logprobs_content)
@@ -1,23 +1,29 @@
 import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
-                    Optional, Tuple)
+                    Optional)
+from typing import Sequence as GenericSequence
+from typing import Tuple
 
 from fastapi import Request
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.entrypoints.openai.protocol import (CompletionRequest,
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
+                                              CompletionRequest,
                                               CompletionResponse,
                                               CompletionResponseChoice,
                                               CompletionResponseStreamChoice,
                                               CompletionStreamResponse,
-                                              LogProbs, UsageInfo)
+                                              UsageInfo)
+# yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.outputs import RequestOutput
+from vllm.sequence import Logprob
 from vllm.utils import merge_async_iterators, random_uuid
 
 logger = init_logger(__name__)
@@ -25,7 +31,7 @@ logger = init_logger(__name__)
 TypeTokenIDs = List[int]
 TypeTopLogProbs = List[Optional[Dict[int, float]]]
 TypeCreateLogProbsFn = Callable[
-    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
+    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
 
 
 def parse_prompt_format(prompt) -> Tuple[bool, list]:
@@ -235,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
                         i]:] if output.logprobs else None
 
                     if request.logprobs is not None:
-                        logprobs = self._create_logprobs(
+                        logprobs = self._create_completion_logprobs(
                             token_ids=delta_token_ids,
                             top_logprobs=top_logprobs,
                             num_output_top_logprobs=request.logprobs,
@@ -317,7 +323,7 @@ class OpenAIServingCompletion(OpenAIServing):
             assert top_logprobs is not None, (
                 "top_logprobs must be provided when logprobs "
                 "is requested")
-            logprobs = self._create_logprobs(
+            logprobs = self._create_completion_logprobs(
                 token_ids=token_ids,
                 top_logprobs=top_logprobs,
                 num_output_top_logprobs=request.logprobs,
@@ -351,3 +357,59 @@ class OpenAIServingCompletion(OpenAIServing):
             choices=choices,
             usage=usage,
         )
+
+    def _create_completion_logprobs(
+        self,
+        token_ids: GenericSequence[int],
+        top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        num_output_top_logprobs: int,
+        initial_text_offset: int = 0,
+    ) -> CompletionLogProbs:
+        """Create logprobs for OpenAI Completion API."""
+        out_text_offset: List[int] = []
+        out_token_logprobs: List[Optional[float]] = []
+        out_tokens: List[str] = []
+        out_top_logprobs: List[Optional[Dict[str, float]]] = []
+
+        last_token_len = 0
+
+        for i, token_id in enumerate(token_ids):
+            step_top_logprobs = top_logprobs[i]
+            if step_top_logprobs is None:
+                token = self.tokenizer.decode(token_id)
+                out_tokens.append(token)
+                out_token_logprobs.append(None)
+                out_top_logprobs.append(None)
+            else:
+                token = self._get_decoded_token(step_top_logprobs[token_id],
+                                                token_id)
+                token_logprob = max(step_top_logprobs[token_id].logprob,
+                                    -9999.0)
+                out_tokens.append(token)
+                out_token_logprobs.append(token_logprob)
+
+                # makes sure to add the top num_output_top_logprobs + 1
+                # logprobs, as defined in the openai API
+                # (cf. https://github.com/openai/openai-openapi/blob/
+                # 893ba52242dbd5387a97b96444ee1c742cfce9bd/openapi.yaml#L7153)
+                out_top_logprobs.append({
+                    # Convert float("-inf") to the
+                    # JSON-serializable float that OpenAI uses
+                    self._get_decoded_token(top_lp[1], top_lp[0]):
+                    max(top_lp[1].logprob, -9999.0)
+                    for i, top_lp in enumerate(step_top_logprobs.items())
+                    if num_output_top_logprobs >= i
+                })
+
+            if len(out_text_offset) == 0:
+                out_text_offset.append(initial_text_offset)
+            else:
+                out_text_offset.append(out_text_offset[-1] + last_token_len)
+            last_token_len = len(token)
+
+        return CompletionLogProbs(
+            text_offset=out_text_offset,
+            token_logprobs=out_token_logprobs,
+            tokens=out_tokens,
+            top_logprobs=out_top_logprobs,
+        )
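Note (illustrative only): in contrast to the nested chat format above, CompletionLogProbs keeps the legacy Completion API layout of parallel lists indexed by output token position. The values below are invented for the sketch, not produced by this code.

example_completion_logprobs = {
    "text_offset": [0, 5],
    "token_logprobs": [-0.2, -1.3],
    "tokens": ["Hello", " world"],
    "top_logprobs": [
        {"Hello": -0.2, "Hi": -2.5},
        {" world": -1.3, " there": -2.9},
    ],
}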
@@ -11,7 +11,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               CompletionRequest,
                                               EmbeddingRequest, ErrorResponse,
-                                              LogProbs, ModelCard, ModelList,
+                                              ModelCard, ModelList,
                                               ModelPermission)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -75,51 +75,6 @@ class OpenAIServing:
             model_cards.extend(lora_cards)
         return ModelList(data=model_cards)
 
-    def _create_logprobs(
-        self,
-        token_ids: List[int],
-        top_logprobs: List[Optional[Dict[int, Logprob]]],
-        num_output_top_logprobs: Optional[int] = None,
-        initial_text_offset: int = 0,
-    ) -> LogProbs:
-        """Create OpenAI-style logprobs."""
-        logprobs = LogProbs()
-        last_token_len = 0
-        if num_output_top_logprobs:
-            logprobs.top_logprobs = []
-
-        for i, token_id in enumerate(token_ids):
-            step_top_logprobs = top_logprobs[i]
-            if step_top_logprobs is None:
-                token = self.tokenizer.decode(token_id)
-                logprobs.tokens.append(token)
-                logprobs.token_logprobs.append(None)
-                assert logprobs.top_logprobs is not None
-                logprobs.top_logprobs.append(None)
-            else:
-                token_logprob = step_top_logprobs[token_id].logprob
-                token = step_top_logprobs[token_id].decoded_token
-                logprobs.tokens.append(token)
-                token_logprob = max(token_logprob, -9999.0)
-                logprobs.token_logprobs.append(token_logprob)
-
-                if num_output_top_logprobs:
-                    assert logprobs.top_logprobs is not None
-                    logprobs.top_logprobs.append({
-                        # Convert float("-inf") to the
-                        # JSON-serializable float that OpenAI uses
-                        p.decoded_token: max(p.logprob, -9999.0)
-                        for i, p in step_top_logprobs.items()
-                    } if step_top_logprobs else None)
-
-            if len(logprobs.text_offset) == 0:
-                logprobs.text_offset.append(initial_text_offset)
-            else:
-                logprobs.text_offset.append(logprobs.text_offset[-1] +
-                                            last_token_len)
-            last_token_len = len(token)
-        return logprobs
-
     def create_error_response(
         self,
         message: str,
@@ -235,3 +190,8 @@ class OpenAIServing:
                 f"Please reduce the length of the messages or completion.", )
         else:
             return input_ids, input_text
+
+    def _get_decoded_token(self, logprob: Logprob, token_id: int) -> str:
+        if logprob.decoded_token is not None:
+            return logprob.decoded_token
+        return self.tokenizer.decode(token_id)