[Frontend] Continuous usage stats in OpenAI completion API (#5742)

This commit is contained in:
jvlunteren 2024-07-05 19:37:09 +02:00 committed by GitHub
parent 0097bb1829
commit f1e15da6fe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 110 additions and 31 deletions

View File

@ -295,25 +295,49 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
model_name: str):
prompt = "What is the capital of France?"
# Test stream=True, stream_options={"include_usage": False}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": False})
# Test stream=True, stream_options=
# {"include_usage": False, "continuous_usage_stats": False}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": False,
"continuous_usage_stats":
False,
})
async for chunk in stream:
assert chunk.usage is None
# Test stream=True, stream_options={"include_usage": True}
stream = await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={"include_usage": True})
# Test stream=True, stream_options=
# {"include_usage": False, "continuous_usage_stats": True}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": False,
"continuous_usage_stats":
True,
})
async for chunk in stream:
assert chunk.usage is None
# Test stream=True, stream_options=
# {"include_usage": True, "continuous_usage_stats": False}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats":
False,
})
async for chunk in stream:
if chunk.choices[0].finish_reason is None:
assert chunk.usage is None
@ -328,7 +352,36 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []
# Test stream=False, stream_options={"include_usage": None}
# Test stream=True, stream_options=
# {"include_usage": True, "continuous_usage_stats": True}
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=True,
stream_options={
"include_usage": True,
"continuous_usage_stats":
True,
})
async for chunk in stream:
assert chunk.usage is not None
assert chunk.usage.prompt_tokens > 0
assert chunk.usage.completion_tokens > 0
assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
chunk.usage.completion_tokens)
if chunk.choices[0].finish_reason is not None:
final_chunk = await stream.__anext__()
assert final_chunk.usage is not None
assert final_chunk.usage.prompt_tokens > 0
assert final_chunk.usage.completion_tokens > 0
assert final_chunk.usage.total_tokens == (
final_chunk.usage.prompt_tokens +
final_chunk.usage.completion_tokens)
assert final_chunk.choices == []
# Test stream=False, stream_options=
# {"include_usage": None}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
@ -337,7 +390,8 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
stream=False,
stream_options={"include_usage": None})
# Test stream=False, stream_options={"include_usage": True}
# Test stream=False, stream_options=
# {"include_usage": True}
with pytest.raises(BadRequestError):
await client.completions.create(model=model_name,
prompt=prompt,
@ -346,6 +400,28 @@ async def test_completion_stream_options(client: openai.AsyncOpenAI,
stream=False,
stream_options={"include_usage": True})
# Test stream=False, stream_options=
# {"continuous_usage_stats": None}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"continuous_usage_stats": None})
# Test stream=False, stream_options=
# {"continuous_usage_stats": True}
with pytest.raises(BadRequestError):
await client.completions.create(
model=model_name,
prompt=prompt,
max_tokens=5,
temperature=0.0,
stream=False,
stream_options={"continuous_usage_stats": True})
@pytest.mark.asyncio
@pytest.mark.parametrize(

View File

@ -103,7 +103,8 @@ class ResponseFormat(OpenAIBaseModel):
class StreamOptions(OpenAIBaseModel):
include_usage: Optional[bool]
include_usage: Optional[bool] = True
continuous_usage_stats: Optional[bool] = True
class FunctionDefinition(OpenAIBaseModel):

View File

@ -271,16 +271,6 @@ class OpenAIServingCompletion(OpenAIServing):
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
stop_reason = output.stop_reason
if output.finish_reason is not None: # return final usage
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
else:
final_usage = None
chunk = CompletionStreamResponse(
id=request_id,
@ -297,7 +287,19 @@ class OpenAIServingCompletion(OpenAIServing):
])
if (request.stream_options
and request.stream_options.include_usage):
chunk.usage = None
if (request.stream_options.continuous_usage_stats
or output.finish_reason is not None):
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
if request.stream_options.continuous_usage_stats:
chunk.usage = usage
else:
chunk.usage = None
response_json = chunk.model_dump_json(exclude_unset=True)
yield f"data: {response_json}\n\n"
@ -309,7 +311,7 @@ class OpenAIServingCompletion(OpenAIServing):
created=created_time,
model=model_name,
choices=[],
usage=final_usage,
usage=usage,
)
final_usage_data = (final_usage_chunk.model_dump_json(
exclude_unset=True, exclude_none=True))