2024-03-25 23:59:47 +09:00
|
|
|
# imports for guided decoding tests
|
|
|
|
import json
|
|
|
|
import re
|
2024-01-17 05:33:14 +00:00
|
|
|
|
2024-03-25 23:59:47 +09:00
|
|
|
import jsonschema
|
|
|
|
import openai # use the official client for correctness check
|
2024-01-17 05:33:14 +00:00
|
|
|
import pytest
|
2024-03-10 19:49:14 -07:00
|
|
|
# using Ray for overall ease of process management, parallel requests,
|
|
|
|
# and debugging.
|
|
|
|
import ray
|
2024-05-01 19:31:22 +00:00
|
|
|
import torch
|
2024-03-10 19:49:14 -07:00
|
|
|
# downloading lora to test lora requests
|
|
|
|
from huggingface_hub import snapshot_download
|
2024-04-27 13:08:24 +08:00
|
|
|
from openai import BadRequestError
|
2024-01-17 05:33:14 +00:00
|
|
|
|
2024-02-26 19:51:53 -08:00
|
|
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
|
|
|
|
2024-05-13 22:50:09 +08:00
|
|
|
from ..utils import ServerRunner
|
|
|
|
|
2024-03-10 19:49:14 -07:00
|
|
|
# Base model launched by the `server` fixture. Any model that ships a chat
# template is suitable here.
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# Model launched by the `embedding_server` fixture.
EMBEDDING_MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
# Hub repo for the LoRA adapter fetched by `zephyr_lora_files`. Strictly it
# expects Mistral-7B-v0.1 as its base; that mismatch is acceptable because
# these tests do not assess generation quality.
LORA_NAME = "typeof/zephyr-7b-beta-lora"
|
2024-01-17 05:33:14 +00:00
|
|
|
|
2024-02-29 14:13:08 -08:00
|
|
|
# JSON schema describing an employee profile. The guided-JSON tests use it as
# the `guided_json` constraint and also interpolate it into the prompt text
# (through its repr), so the key order below is significant.
TEST_SCHEMA = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
        "skills": {
            "type": "array",
            "items": {"type": "string", "maxLength": 10},
            "minItems": 3,
        },
        "work history": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "company": {"type": "string"},
                    "duration": {"type": "string"},
                    "position": {"type": "string"},
                },
                "required": ["company", "position"],
            },
        },
    },
    "required": ["name", "age", "skills", "work history"],
}
|
|
|
|
|
2024-03-10 19:49:14 -07:00
|
|
|
# Dotted-quad IPv4 address: three "octet." groups followed by a final octet.
# Each octet alternative accepts 0-255 without leading zeros.
TEST_REGEX = (
    r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}"  # first three octets + dots
    r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)"  # last octet
)
|
2024-02-29 14:13:08 -08:00
|
|
|
|
|
|
|
# Fixed list of programming-language names. Presumably these are the allowed
# options for guided-choice requests elsewhere in this module (TODO confirm).
TEST_CHOICE = [
    "Python",
    "Java",
    "JavaScript",
    "C++",
    "C#",
    "PHP",
    "TypeScript",
    "Ruby",
    "Swift",
    "Kotlin",
]
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
# Apply the `openai` marker to every test in this module.
pytestmark = [pytest.mark.openai]
|
2024-01-17 05:33:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope="session")
def zephyr_lora_files():
    """Fetch the LoRA adapter repo once per session; return its local path."""
    local_path = snapshot_download(repo_id=LORA_NAME)
    return local_path
|
|
|
|
|
|
|
|
|
2024-05-11 11:30:37 -07:00
|
|
|
@pytest.fixture(scope="module")
def server(zephyr_lora_files):
    """Run the API server under test for every test in this module.

    The server is started through the Ray ``ServerRunner`` actor. It serves
    ``MODEL_NAME`` together with two LoRA adapters, ``zephyr-lora`` and
    ``zephyr-lora2``, both pointing at the same downloaded adapter files.

    Yields:
        The ``ServerRunner`` actor handle, after its remote ``ready`` call
        has returned.

    Ray is shut down in a ``finally`` block. Without it, a server that fails
    to start would leave Ray initialized, because pytest skips the
    post-``yield`` teardown of a generator fixture whose setup raised.
    """
    ray.init()
    try:
        server_runner = ServerRunner.remote([
            "--model",
            MODEL_NAME,
            # use half precision for speed and memory savings in CI environment
            "--dtype",
            "bfloat16",
            "--max-model-len",
            "8192",
            "--enforce-eager",
            "--gpu-memory-utilization",
            "0.75",
            # lora config below
            "--enable-lora",
            "--lora-modules",
            f"zephyr-lora={zephyr_lora_files}",
            f"zephyr-lora2={zephyr_lora_files}",
            "--max-lora-rank",
            "64",
            "--max-cpu-loras",
            "2",
            "--max-num-seqs",
            "128",
        ])
        # Block until the actor reports the server is ready.
        ray.get(server_runner.ready.remote())
        yield server_runner
    finally:
        ray.shutdown()
|
|
|
|
|
|
|
|
|
2024-05-11 11:30:37 -07:00
|
|
|
@pytest.fixture(scope="module")
def embedding_server():
    """Run an API server for ``EMBEDDING_MODEL_NAME`` via ``ServerRunner``.

    No LoRA adapters are configured, so this fixture does not request
    ``zephyr_lora_files``. That avoids downloading the adapter when only
    embedding tests run.

    Yields:
        The ``ServerRunner`` actor handle, after its remote ``ready`` call
        has returned.
    """
    # Drop any Ray runtime that is still initialized (e.g. by ``server``)
    # before starting a fresh one.
    ray.shutdown()
    ray.init()
    try:
        server_runner = ServerRunner.remote([
            "--model",
            EMBEDDING_MODEL_NAME,
            # use half precision for speed and memory savings in CI environment
            "--dtype",
            "bfloat16",
            "--enforce-eager",
            "--gpu-memory-utilization",
            "0.75",
            "--max-model-len",
            "8192",
        ])
        ray.get(server_runner.ready.remote())
        yield server_runner
    finally:
        # Runs on normal teardown and when startup fails before ``yield``.
        ray.shutdown()
|
|
|
|
|
|
|
|
|
2024-05-03 20:04:14 +02:00
|
|
|
@pytest.fixture(scope="module")
def client():
    """Async OpenAI client aimed at the local test server's ``/v1`` API."""
    return openai.AsyncOpenAI(
        api_key="token-abc123",
        base_url="http://localhost:8000/v1",
    )
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_check_models(server, client: openai.AsyncOpenAI):
    """The model list is the base model followed by both LoRA adapters."""
    page = await client.models.list()
    available = page.data
    base_model, adapters = available[0], available[1:]

    assert base_model.id == MODEL_NAME
    # every entry, adapters included, reports the base model as its root
    assert all(entry.root == MODEL_NAME for entry in available)
    assert adapters[0].id == "zephyr-lora"
    assert adapters[1].id == "zephyr-lora2"
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(server, client: openai.AsyncOpenAI,
                                 model_name: str):
    """Greedy 5-token completion from a text prompt and from token IDs.

    Both requests target the parametrized ``model_name``, so every LoRA
    adapter (not just the base model) is also exercised with a token-ID
    prompt.
    """
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    assert completion.choices[0].finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
|
|
|
|
|
2024-01-17 05:33:14 +00:00
|
|
|
|
2024-05-30 11:52:14 +02:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(server, client: openai.AsyncOpenAI,
                           model_name: str):
    """``logprobs=None`` produces a choice without logprobs.

    The request targets the parametrized ``model_name`` so that each LoRA
    adapter is covered as well as the base model.
    """
    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=None,
    )
    choice = completion.choices[0]
    assert choice.logprobs is None
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
                             model_name: str):
    """``logprobs=0`` returns logprobs with one top entry per position.

    The request targets the parametrized ``model_name`` so that each LoRA
    adapter is covered as well as the base model.
    """
    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=0,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert len(choice.logprobs.top_logprobs[0]) == 1
|
2024-05-30 11:52:14 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(server, client: openai.AsyncOpenAI,
                             model_name: str):
    """``logprobs=5`` returns 5 or 6 top entries at the first position.

    The request targets the parametrized ``model_name`` so that the LoRA
    adapter is covered as well as the base model.
    """
    # test using token IDs
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=5,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
|
2024-05-30 11:52:14 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
                                            model_name: str):
    """A completion request with ``logprobs=6`` is rejected.

    Both the non-streaming and the streaming request are expected to fail.
    The server must then keep serving ordinary requests. All requests target
    the parametrized ``model_name``.
    """
    with pytest.raises((openai.BadRequestError, openai.APIError)):
        # test using token IDs
        await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            logprobs=6,
        )

    with pytest.raises((openai.BadRequestError, openai.APIError)):
        # The error may surface on the create call itself or while the stream
        # is being consumed, so both happen inside the `raises` block.
        stream = await client.completions.create(
            model=model_name,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            logprobs=6,
            stream=True,
        )
        async for _ in stream:
            pass

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    text = completion.choices[0].text
    assert text is not None
|
2024-03-30 00:38:21 +08:00
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_single_chat_session(server, client: openai.AsyncOpenAI,
                                   model_name: str):
    """Run a two-turn chat.

    The first turn requests the top-5 logprobs; the second turn is a plain
    follow-up in the same conversation.
    """
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    # test single completion
    first = await client.chat.completions.create(model=model_name,
                                                 messages=conversation,
                                                 max_tokens=10,
                                                 logprobs=True,
                                                 top_logprobs=5)
    assert first.id is not None
    assert first.choices is not None and len(first.choices) == 1
    head = first.choices[0]
    assert head.message is not None
    assert head.logprobs is not None
    first_token_alternatives = head.logprobs.content[0].top_logprobs
    assert first_token_alternatives is not None
    assert len(first_token_alternatives) == 5

    reply = head.message
    assert reply.content is not None and len(reply.content) >= 10
    assert reply.role == "assistant"
    conversation.append({"role": "assistant", "content": reply.content})

    # test multi-turn dialogue
    conversation.append({
        "role": "user",
        "content": "express your result in json"
    })
    second = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=10,
    )
    follow_up = second.choices[0].message
    assert follow_up.content is not None and len(follow_up.content) >= 0
|
2024-03-04 11:54:06 -08:00
|
|
|
|
|
|
|
|
2024-05-30 11:52:14 +02:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
                                model_name: str):
    """A chat request with ``logprobs=False`` returns no logprobs."""
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    response = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=5,
        temperature=0.0,
        logprobs=False,
    )

    assert response.choices[0].logprobs is None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
                                  model_name: str):
    """``top_logprobs=0`` yields logprobs with at most one alternative."""
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    response = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=5,
        temperature=0.0,
        logprobs=True,
        top_logprobs=0,
    )

    token_logprobs = response.choices[0].logprobs
    assert token_logprobs is not None
    assert token_logprobs.content is not None
    first_token = token_logprobs.content[0]
    assert len(first_token.top_logprobs) <= 1
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME, "zephyr-lora"])
async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
                                  model_name: str):
    """``top_logprobs=5`` yields at most six alternatives per token."""
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    response = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=5,
        temperature=0.0,
        logprobs=True,
        top_logprobs=5,
    )

    token_logprobs = response.choices[0].logprobs
    assert token_logprobs is not None
    assert token_logprobs.content is not None
    first_token = token_logprobs.content[0]
    assert len(first_token.top_logprobs) <= 6
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
                                      model_name: str):
    """Excessive ``top_logprobs`` is rejected and the server stays usable."""
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    # Default max_logprobs is 20, so this should raise an error
    with pytest.raises((openai.BadRequestError, openai.APIError)):
        stream = await client.chat.completions.create(
            model=model_name,
            messages=conversation,
            max_tokens=10,
            logprobs=True,
            top_logprobs=21,
            stream=True,
        )
        async for _ in stream:
            pass

    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=model_name,
            messages=conversation,
            max_tokens=10,
            logprobs=True,
            top_logprobs=30,
            stream=False,
        )

    # the server should still work afterwards
    response = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=10,
        stream=False,
    )
    reply = response.choices[0].message
    assert reply.content is not None and len(reply.content) >= 0
|
2024-01-17 05:33:14 +00:00
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(server, client: openai.AsyncOpenAI,
                                    model_name: str):
    """Streamed completion text concatenates to the non-streamed result."""
    prompt = "What is an LLM?"

    reference = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
    )
    expected_text = reference.choices[0].text

    stream = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )
    pieces = []
    finishes = 0
    async for chunk in stream:
        head = chunk.choices[0]
        pieces.append(head.text)
        if head.finish_reason is not None:
            finishes += 1

    # finish reason should only return in last block
    assert finishes == 1
    # `head` is the first choice of the last chunk received
    assert head.finish_reason == "length"
    assert head.text
    assert "".join(pieces) == expected_text
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_chat_streaming(server, client: openai.AsyncOpenAI,
                              model_name: str):
    """Streamed chat deltas reproduce the non-streamed reply and finish."""
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    # test single completion
    reference = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=10,
        temperature=0.0,
    )
    expected_text = reference.choices[0].message.content
    expected_finish = reference.choices[0].finish_reason

    # test streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=conversation,
        max_tokens=10,
        temperature=0.0,
        stream=True,
    )
    pieces = []
    finishes = 0
    async for chunk in stream:
        head = chunk.choices[0]
        delta = head.delta
        if delta.role:
            assert delta.role == "assistant"
        if delta.content:
            pieces.append(delta.content)
        if head.finish_reason is not None:
            finishes += 1

    # finish reason should only return in last block
    assert finishes == 1
    # `head` / `delta` refer to the last chunk received
    assert head.finish_reason == expected_finish
    assert delta.content
    assert "".join(pieces) == expected_text
|
|
|
|
|
|
|
|
|
2024-06-10 17:22:09 +03:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_chat_completion_stream_options(server,
                                              client: openai.AsyncOpenAI,
                                              model_name: str):
    """Check ``stream_options={"include_usage": ...}`` for chat completions.

    * Streaming with ``include_usage=False``: no chunk carries ``usage``.
    * Streaming with ``include_usage=True``: content chunks carry no
      ``usage``. One extra chunk after the finishing chunk carries the token
      counts and has empty ``choices``.
    * ``stream_options`` together with ``stream=False`` is rejected with
      ``BadRequestError``.
    """
    messages = [{
        "role": "system",
        "content": "You are a helpful assistant."
    }, {
        "role": "user",
        "content": "What is the capital of France?"
    }]

    # Test stream=True, stream_options={"include_usage": False}
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": False})
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options={"include_usage": True}
    stream = await client.chat.completions.create(
        model=model_name,
        messages=messages,
        max_tokens=10,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": True})

    async for chunk in stream:
        if chunk.choices[0].finish_reason is None:
            assert chunk.usage is None
        else:
            assert chunk.usage is None
            # The usage-only chunk follows the finishing chunk. Pull it
            # explicitly, since it has no choices to index into.
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=False, stream_options={"include_usage": None}
    with pytest.raises(BadRequestError):
        await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=10,
            temperature=0.0,
            stream=False,
            stream_options={"include_usage": None})

    # Test stream=False, stream_options={"include_usage": True}
    with pytest.raises(BadRequestError):
        await client.chat.completions.create(
            model=model_name,
            messages=messages,
            max_tokens=10,
            temperature=0.0,
            stream=False,
            stream_options={"include_usage": True})
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_completion_stream_options(server, client: openai.AsyncOpenAI,
                                         model_name: str):
    """Check ``stream_options={"include_usage": ...}`` for completions.

    * Streaming with ``include_usage=False``: no chunk carries ``usage``.
    * Streaming with ``include_usage=True``: text chunks carry no ``usage``.
      One extra chunk after the finishing chunk carries the token counts and
      has empty ``choices``.
    * ``stream_options`` together with ``stream=False`` is rejected with
      ``BadRequestError``.
    """
    prompt = "What is the capital of France?"

    # Test stream=True, stream_options={"include_usage": False}
    stream = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": False})
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options={"include_usage": True}
    stream = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
        stream=True,
        stream_options={"include_usage": True})
    async for chunk in stream:
        if chunk.choices[0].finish_reason is None:
            assert chunk.usage is None
        else:
            assert chunk.usage is None
            # The usage-only chunk follows the finishing chunk. Pull it
            # explicitly, since it has no choices to index into.
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=False, stream_options={"include_usage": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": None})

    # Test stream=False, stream_options={"include_usage": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": True})
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(server, client: openai.AsyncOpenAI,
                                 model_name: str):
    """Batched prompts: a plain list, n=2 with beam search, and streaming."""
    prompts = ["Hello, my name is", "Hello, my name is"]
    same_prompt_msg = "two copies of the same prompt should be the same"

    # test simple list
    batch = await client.completions.create(
        model=model_name,
        prompt=prompts,
        max_tokens=5,
        temperature=0.0,
    )
    assert len(batch.choices) == 2
    assert batch.choices[0].text == batch.choices[1].text

    # test n = 2
    batch = await client.completions.create(
        model=model_name,
        prompt=prompts,
        n=2,
        max_tokens=5,
        temperature=0.0,
        extra_body=dict(
            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
            # for official client.
            use_beam_search=True),
    )
    assert len(batch.choices) == 4
    texts = [choice.text for choice in batch.choices]
    assert texts[0] != texts[1], "beam search should be different"
    assert texts[0] == texts[2], same_prompt_msg
    assert texts[1] == texts[3], same_prompt_msg

    # test streaming
    stream = await client.completions.create(
        model=model_name,
        prompt=prompts,
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )
    streamed = ["", ""]
    async for chunk in stream:
        assert len(chunk.choices) == 1
        piece = chunk.choices[0]
        # `index` identifies which prompt this piece belongs to
        streamed[piece.index] += piece.text
    assert streamed[0] == streamed[1]
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_logits_bias(server, client: openai.AsyncOpenAI):
    """``logit_bias`` can force a token (+100) or ban tokens (-100)."""
    prompt = "Hello, my name is"
    max_tokens = 5
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection: a +100 bias on one token id.
    token_id = 1000
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
        seed=42,
    )
    assert completion.choices[0].text is not None and len(
        completion.choices[0].text) >= 5
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    # Round-trip the forced sequence (one copy per generated token) through
    # decode/encode, as happened to the server's output text, so that both
    # sides are compared in the same tokenization.
    expected_tokens = tokenizer(tokenizer.decode([token_id] * max_tokens),
                                add_special_tokens=False)["input_ids"]
    assert all(
        response == expected
        for response, expected in zip(response_tokens, expected_tokens))

    # Test ban: first get the unbiased greedy answer...
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
    )
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    first_response = completion.choices[0].text
    # ...then push every one of its tokens to -100 and expect a new answer.
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token): -100
                    for token in response_tokens},
    )
    assert first_response != completion.choices[0].text
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(server, client: openai.AsyncOpenAI,
                                      guided_decoding_backend: str):
    """Three sampled completions each parse as JSON matching TEST_SCHEMA."""
    prompt = (f"Give an example JSON for an employee profile "
              f"that fits this schema: {TEST_SCHEMA}")
    guided_options = dict(guided_json=TEST_SCHEMA,
                          guided_decoding_backend=guided_decoding_backend)

    completion = await client.completions.create(model=MODEL_NAME,
                                                 prompt=prompt,
                                                 n=3,
                                                 temperature=1.0,
                                                 max_tokens=500,
                                                 extra_body=guided_options)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 3
    for choice in completion.choices:
        assert choice.text is not None
        parsed = json.loads(choice.text)
        jsonschema.validate(instance=parsed, schema=TEST_SCHEMA)
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_chat(server, client: openai.AsyncOpenAI,
                                guided_decoding_backend: str):
    """Two guided-JSON chat turns both match TEST_SCHEMA and differ."""
    conversation = [
        {"role": "system", "content": "you are a helpful assistant"},
        {
            "role": "user",
            "content": f"Give an example JSON for an employee profile that "
            f"fits this schema: {TEST_SCHEMA}"
        },
    ]

    async def ask():
        # One guided chat turn; returns (parsed JSON, raw reply text).
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=conversation,
            max_tokens=1000,
            extra_body=dict(guided_json=TEST_SCHEMA,
                            guided_decoding_backend=guided_decoding_backend))
        reply = response.choices[0].message
        assert reply.content is not None
        parsed = json.loads(reply.content)
        jsonschema.validate(instance=parsed, schema=TEST_SCHEMA)
        return parsed, reply.content

    json1, first_text = await ask()

    conversation.append({"role": "assistant", "content": first_text})
    conversation.append({
        "role": "user",
        "content": "Give me another one with a different name and age"
    })
    json2, _ = await ask()

    assert json1["name"] != json2["name"]
    assert json1["age"] != json2["age"]
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(server, client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str):
    """Three sampled completions each fully match TEST_REGEX."""
    guided_options = dict(guided_regex=TEST_REGEX,
                          guided_decoding_backend=guided_decoding_backend)

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}",
        n=3,
        temperature=1.0,
        max_tokens=20,
        extra_body=guided_options)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 3
    for choice in completion.choices:
        assert choice.text is not None
        assert re.fullmatch(TEST_REGEX, choice.text) is not None
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_regex_chat(server, client: openai.AsyncOpenAI,
                                 guided_decoding_backend: str):
    """Check regex-guided chat over two turns.

    Both replies must fully match ``TEST_REGEX``, and the second (requested
    as "a different one") must not equal the first.
    """

    async def next_ip(history):
        # One regex-constrained request. The reply must match in full.
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=history,
            max_tokens=20,
            extra_body=dict(guided_regex=TEST_REGEX,
                            guided_decoding_backend=guided_decoding_backend))
        address = response.choices[0].message.content
        assert address is not None
        assert re.fullmatch(TEST_REGEX, address) is not None
        return address

    history = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        f"Give an example IP address with this regex: {TEST_REGEX}"
    }]
    first_ip = await next_ip(history)

    history.extend([
        {
            "role": "assistant",
            "content": first_ip
        },
        {
            "role": "user",
            "content": "Give me a different one"
        },
    ])
    second_ip = await next_ip(history)

    assert first_ip != second_ip
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(server, client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str):
    """Check choice-guided completion: each of n=2 samples is in TEST_CHOICE."""
    sample_count = 2
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="The best language for type-safe systems programming is ",
        n=sample_count,
        temperature=1.0,
        max_tokens=10,
        extra_body=dict(guided_choice=TEST_CHOICE,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert (completion.choices is not None
            and len(completion.choices) == sample_count)
    for sample in completion.choices:
        assert sample.text in TEST_CHOICE
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat(server, client: openai.AsyncOpenAI,
                                  guided_decoding_backend: str):
    """Check choice-guided chat: two turns pick distinct TEST_CHOICE entries.

    The first pick is replayed as assistant history and the user asks for
    another option. The second pick must also be a member of
    ``TEST_CHOICE`` and must not repeat the first.
    """

    async def pick(history):
        # One choice-constrained request. Returns the chosen option text.
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=history,
            max_tokens=10,
            extra_body=dict(guided_choice=TEST_CHOICE,
                            guided_decoding_backend=guided_decoding_backend))
        option = response.choices[0].message.content
        assert option in TEST_CHOICE
        return option

    history = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        "The best language for type-safe systems programming is "
    }]
    first_pick = await pick(history)

    history += [
        {
            "role": "assistant",
            "content": first_pick
        },
        {
            "role": "user",
            "content": "I disagree, pick another one"
        },
    ]
    second_pick = await pick(history)

    assert first_pick != second_pick
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI,
                                          guided_decoding_backend: str):
    """Check that malformed guided-decoding parameters are rejected (HTTP 400).

    Three cases are covered: an integer passed as ``guided_json``, a mapping
    passed as ``guided_regex``, and ``guided_regex`` combined with
    ``guided_json`` in one request.
    """
    # guided_json given an integer instead of a schema
    integer_schema_body = dict(guided_json=42,
                               guided_decoding_backend=guided_decoding_backend)
    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example JSON that fits this schema: 42",
            extra_body=integer_schema_body)

    # guided_regex given a mapping instead of a pattern string
    chat_history = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        "The best language for type-safe systems programming is "
    }]
    mapping_regex_body = dict(guided_regex={1: "Python", 2: "C++"})
    with pytest.raises(openai.BadRequestError):
        _ = await client.chat.completions.create(model=MODEL_NAME,
                                                 messages=chat_history,
                                                 extra_body=mapping_regex_body)

    # two guided modes requested at once
    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA))
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
                                           guided_decoding_backend: str):
    """Check that logprobs returned under choice guidance are valid floats.

    Only the ``top_logprobs`` alternatives of the first generated token are
    inspected. Each must be a float no lower than -9999.0.
    """
    history = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        "The best language for type-safe systems programming is "
    }]
    response = await client.chat.completions.create(
        model=MODEL_NAME,
        messages=history,
        max_tokens=10,
        logprobs=True,
        top_logprobs=5,
        extra_body=dict(guided_choice=TEST_CHOICE,
                        guided_decoding_backend=guided_decoding_backend))
    first_token = response.choices[0].logprobs.content[0]

    # -9999.0 is the minimum logprob returned by OpenAI
    logprob_floor = -9999.0
    for alternative in first_token.top_logprobs:
        assert isinstance(alternative.logprob, float)
        assert alternative.logprob >= logprob_floor
|
2024-04-18 16:12:55 -05:00
|
|
|
|
|
|
|
|
2024-06-04 01:25:29 +02:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_named_tool_use(server, client: openai.AsyncOpenAI,
                              guided_decoding_backend: str):
    """Check that forcing a named tool call yields schema-valid arguments.

    The non-streaming path runs first. Its arguments are replayed as
    assistant history, and the streaming path is then checked: no content
    text may be emitted, the streamed argument fragments must concatenate
    into schema-valid JSON, and exactly one chunk may carry a finish reason.
    The two generated profiles must differ in ``name`` and ``age``.
    """
    dummy_tool = {
        "type": "function",
        "function": {
            "name": "dummy_function_name",
            "description": "This is a dummy function",
            "parameters": TEST_SCHEMA
        }
    }
    forced_choice = {
        "type": "function",
        "function": {
            "name": "dummy_function_name"
        }
    }
    history = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        f"Give an example JSON for an employee profile that "
        f"fits this schema: {TEST_SCHEMA}"
    }]

    # non-streaming
    result = await client.chat.completions.create(model=MODEL_NAME,
                                                  messages=history,
                                                  max_tokens=1000,
                                                  tools=[dummy_tool],
                                                  tool_choice=forced_choice)
    reply = result.choices[0].message
    assert len(reply.content) == 0
    first_arguments = reply.tool_calls[0].function.arguments
    first_profile = json.loads(first_arguments)
    jsonschema.validate(instance=first_profile, schema=TEST_SCHEMA)

    history.append({"role": "assistant", "content": first_arguments})
    history.append({
        "role": "user",
        "content": "Give me another one with a different name and age"
    })

    # streaming
    stream = await client.chat.completions.create(model=MODEL_NAME,
                                                  messages=history,
                                                  max_tokens=1000,
                                                  tools=[dummy_tool],
                                                  tool_choice=forced_choice,
                                                  stream=True)

    fragments = []
    finished_chunks = 0
    async for chunk in stream:
        first_choice = chunk.choices[0]
        delta = first_choice.delta
        if delta.role:
            assert delta.role == "assistant"
        assert delta.content is None or len(delta.content) == 0
        if delta.tool_calls:
            fragments.append(delta.tool_calls[0].function.arguments)
        if first_choice.finish_reason is not None:
            finished_chunks += 1

    # finish reason should only return in last block
    assert finished_chunks == 1
    second_profile = json.loads("".join(fragments))
    jsonschema.validate(instance=second_profile, schema=TEST_SCHEMA)
    assert first_profile["name"] != second_profile["name"]
    assert first_profile["age"] != second_profile["age"]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_required_tool_use_not_yet_supported(
        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
    """Check that ``tool_choice`` "required" and "auto" are rejected (HTTP 400).

    Each request supplies the same single tool definition and differs only
    in the ``tool_choice`` value. "required" is tried first, then "auto".
    """
    history = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        f"Give an example JSON for an employee profile that "
        f"fits this schema: {TEST_SCHEMA}"
    }]

    for unsupported_choice in ("required", "auto"):
        with pytest.raises(openai.BadRequestError):
            await client.chat.completions.create(
                model=MODEL_NAME,
                messages=history,
                max_tokens=1000,
                tools=[{
                    "type": "function",
                    "function": {
                        "name": "dummy_function_name",
                        "description": "This is a dummy function",
                        "parameters": TEST_SCHEMA
                    }
                }],
                tool_choice=unsupported_choice)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend", ["outlines"])
async def test_inconsistent_tool_choice_and_tools(
        server, client: openai.AsyncOpenAI, guided_decoding_backend: str):
    """Check that a named ``tool_choice`` without a matching tool is rejected.

    Two inconsistent requests are sent, and both must fail with HTTP 400:
    a named ``tool_choice`` with no ``tools`` list at all, and a named
    ``tool_choice`` whose function name is absent from the supplied
    ``tools``.
    """
    history = [{
        "role": "system",
        "content": "you are a helpful assistant"
    }, {
        "role":
        "user",
        "content":
        f"Give an example JSON for an employee profile that "
        f"fits this schema: {TEST_SCHEMA}"
    }]

    def named_choice(function_name):
        # tool_choice payload that forces a call to ``function_name``.
        return {"type": "function", "function": {"name": function_name}}

    # tool_choice names a function but no tools are provided
    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=history,
            max_tokens=1000,
            tool_choice=named_choice("dummy_function_name"))

    # tool_choice names a function that is not among the provided tools
    with pytest.raises(openai.BadRequestError):
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=history,
            max_tokens=1000,
            tools=[{
                "type": "function",
                "function": {
                    "name": "dummy_function_name",
                    "description": "This is a dummy function",
                    "parameters": TEST_SCHEMA
                }
            }],
            tool_choice=named_choice("nondefined_function_name"))
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_response_format_json_object(server, client: openai.AsyncOpenAI):
    """Check that ``response_format={"type": "json_object"}`` yields the JSON.

    The same request is issued twice, and both replies must decode to
    exactly ``{"result": 2}``.
    """
    question = ('what is 1+1? please respond with a JSON object, '
                'the format is {"result": 2}')
    expected = {"result": 2}
    for _ in range(2):
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{
                "role": "user",
                "content": question
            }],
            response_format={"type": "json_object"})

        decoded = json.loads(response.choices[0].message.content)
        assert decoded == expected, decoded
|
2024-03-16 13:35:27 -07:00
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_extra_fields(server, client: openai.AsyncOpenAI):
    """Check that unknown keys inside a chat message are rejected (HTTP 400).

    The server's error message must mention ``extra_forbidden``.
    """
    message_with_unknown_key = {
        "role": "system",
        "content": "You are a helpful assistant.",
        "extra_field": "0",
    }
    with pytest.raises(BadRequestError) as exc_info:
        await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[message_with_unknown_key],  # type: ignore
            temperature=0,
            seed=0)

    assert "extra_forbidden" in exc_info.value.message
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_complex_message_content(server, client: openai.AsyncOpenAI):
    """Check that a user message given as a list of typed parts is accepted.

    With temperature 0 and a fixed seed, the reply must be exactly ``"2"``.
    """
    text_part = {
        "type":
        "text",
        "text":
        "what is 1+1? please provide the result without any other text."
    }
    response = await client.chat.completions.create(model=MODEL_NAME,
                                                    messages=[{
                                                        "role": "user",
                                                        "content": [text_part]
                                                    }],
                                                    temperature=0,
                                                    seed=0)
    assert response.choices[0].message.content == "2"
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_custom_role(server, client: openai.AsyncOpenAI):
    """Check that a custom role gives the same reply for both content shapes.

    The same question is sent under role "my-custom-role", first as a plain
    string and then as a single text part. Under temperature 0 and a fixed
    seed, the two replies must be identical.
    """
    # Not sure how the model handles custom roles so we just check that
    # both string and complex message content are handled in the same way
    content_variants = (
        "what is 1+1?",
        [{
            "type": "text",
            "text": "what is 1+1?"
        }],
    )
    replies = []
    for content in content_variants:
        response = await client.chat.completions.create(
            model=MODEL_NAME,
            messages=[{
                "role": "my-custom-role",
                "content": content,
            }],  # type: ignore
            temperature=0,
            seed=0)
        replies.append(response.choices[0].message.content)

    string_reply, parts_reply = replies
    assert string_reply == parts_reply
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_guided_grammar(server, client: openai.AsyncOpenAI):
    """Check grammar-guided completion against a tiny Lark SQL grammar.

    The generated text must parse under the same grammar and, once
    stripped, equal the expected statement with its spaces removed.
    """
    simple_sql_grammar = """
start: select_statement

select_statement: "SELECT" column "from" table "where" condition

column: "col_1" | "col_2"
table: "table_1" | "table_2"
condition: column "=" number

number: "1" | "2"
"""

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=("Generate a sql state that select col_1 from "
                "table_1 where it is equals to 1"),
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_grammar=simple_sql_grammar))
    generated = completion.choices[0].text

    # use Lark to parse the output, and make sure it's a valid parse tree
    from lark import Lark
    Lark(simple_sql_grammar).parse(generated)

    # The grammar defines no whitespace terminal and no %ignore directive,
    # so the expected statement is compared with its spaces stripped.
    expected_sql = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")
    assert generated.strip() == expected_sql
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
                                       model_name: str, logprobs_arg: int):
    """Check that ``echo=True`` completions return the prompt plus logprobs.

    Runs once with a text prompt and once with a token-ID prompt. The
    returned text must start with the (decoded) prompt. The logprob arrays
    must be longer than the 5 generated tokens and hold ``None`` in the
    first position. Each later ``top_logprobs`` entry must have between
    ``max(logprobs_arg, 1)`` and ``logprobs_arg + 1`` alternatives.
    """
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # test using text and token IDs
    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
        completion = await client.completions.create(model=model_name,
                                                     prompt=prompt,
                                                     max_tokens=5,
                                                     temperature=0.0,
                                                     echo=True,
                                                     logprobs=logprobs_arg)

        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                             list) else prompt
        # Literal prefix check. Building a regex from the prompt (as in
        # r"^" + prompt_text) misbehaves on metacharacters such as "(", "+"
        # or "[" that a decoded prompt may contain.
        assert (completion.choices[0].text is not None
                and completion.choices[0].text.startswith(prompt_text))
        logprobs = completion.choices[0].logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) > 5
        assert (len(logprobs.token_logprobs) > 5
                and logprobs.token_logprobs[0] is None)
        assert (len(logprobs.top_logprobs) > 5
                and logprobs.top_logprobs[0] is None)
        # Upper bound is logprobs_arg + 1, presumably because the sampled
        # token may be reported in addition to the top-k alternatives.
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) > 5
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
async def test_long_seed(server, client: openai.AsyncOpenAI):
    """Check that seeds just outside the ``torch.long`` range are rejected.

    Each request must fail with HTTP 400, and the error message must name
    the violated lower- or upper-bound constraint.
    """
    long_bounds = torch.iinfo(torch.long)
    out_of_range_seeds = [long_bounds.min - 1, long_bounds.max + 1]
    for bad_seed in out_of_range_seeds:
        with pytest.raises(BadRequestError) as exc_info:
            await client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{
                    "role": "system",
                    "content": "You are a helpful assistant.",
                }],
                temperature=0,
                seed=bad_seed)

        error_text = exc_info.value.message
        assert ("greater_than_equal" in error_text
                or "less_than_equal" in error_text)
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI,
                                model_name: str):
    """Check single-input embeddings for a text prompt and a token-ID prompt.

    Each request must return exactly one 4096-dimensional embedding, and
    the usage block must report the expected prompt token count and zero
    completion tokens.
    """
    # (request input, expected prompt/total token count). The loop variable
    # is not called ``input`` so that the builtin is not shadowed.
    cases = [
        # test single embedding
        (["The chef prepared a delicious meal."], 9),
        # test using token IDs
        ([1, 1, 1, 1, 1], 5),
    ]
    for text_or_ids, expected_tokens in cases:
        embeddings = await client.embeddings.create(
            model=model_name,
            input=text_or_ids,
            encoding_format="float",
        )
        assert embeddings.id is not None
        assert embeddings.data is not None and len(embeddings.data) == 1
        assert len(embeddings.data[0].embedding) == 4096
        assert embeddings.usage.completion_tokens == 0
        assert embeddings.usage.prompt_tokens == expected_tokens
        assert embeddings.usage.total_tokens == expected_tokens
|
|
|
|
|
|
|
|
|
2024-05-29 04:29:31 +08:00
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_batch_embedding(embedding_server, client: openai.AsyncOpenAI,
                               model_name: str):
    """Check batched embeddings for a list of strings and a list of ID lists.

    Each batch must return one entry per input, and the first embedding
    must be 4096-dimensional. Usage counts are checked only for the
    token-ID batch.
    """

    async def embed(batch, expected_count):
        # Embed ``batch`` and check the shape of the response.
        result = await client.embeddings.create(
            model=model_name,
            input=batch,
            encoding_format="float",
        )
        assert result.id is not None
        assert result.data is not None and len(result.data) == expected_count
        assert len(result.data[0].embedding) == 4096
        return result

    # test List[str]
    sentences = [
        "The cat sat on the mat.", "A feline was resting on a rug.",
        "Stars twinkle brightly in the night sky."
    ]
    await embed(sentences, 3)

    # test List[List[int]]
    token_batches = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                     [25, 32, 64, 77]]
    usage = (await embed(token_batches, 4)).usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 17
    assert usage.total_tokens == 17
|
|
|
|
|
|
|
|
|
2024-01-17 05:33:14 +00:00
|
|
|
if __name__ == "__main__":
    # pytest.main() returns the session exit code instead of exiting. Pass it
    # to SystemExit so a failing run exits non-zero when the module is
    # executed directly.
    raise SystemExit(pytest.main([__file__]))
|