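"""End-to-end tests for the vLLM OpenAI-compatible completions API.

Covers base-model and LoRA completions, logprobs, streaming and
stream_options, batched prompts, logit_bias, guided decoding, and the
tokenize/detokenize endpoints.
"""
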
# imports for guided decoding tests
import json
import re
from typing import List

import jsonschema
import openai  # use the official client for correctness check
import pytest
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
from openai import BadRequestError

from vllm.transformers_utils.tokenizer import get_tokenizer

from ...utils import VLLM_PATH, RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def ray_ctx():
    ray.init(runtime_env={"working_dir": VLLM_PATH})
    yield
    ray.shutdown()


@pytest.fixture(scope="module")
def server(zephyr_lora_files, ray_ctx):
    return RemoteOpenAIServer([
        "--model",
        MODEL_NAME,
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        f"zephyr-lora={zephyr_lora_files}",
        f"zephyr-lora2={zephyr_lora_files}",
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "128",
    ])


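# The client fixture returns an async OpenAI client pointed at the
# RemoteOpenAIServer above, so every test below exercises the real HTTP API.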
@pytest.fixture(scope="module")
def client(server):
    return server.get_async_client()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
    completion = await client.completions.create(model=model_name,
                                                 prompt="Hello, my name is",
                                                 max_tokens=5,
                                                 temperature=0.0)

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1

    choice = completion.choices[0]
    assert len(choice.text) >= 5
    assert choice.finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11)

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 5


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
async def test_no_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=None,
    )
    choice = completion.choices[0]
    assert choice.logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_zero_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=0,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert len(choice.logprobs.top_logprobs[0]) == 1


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_some_logprobs(client: openai.AsyncOpenAI, model_name: str):
    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
        logprobs=5,
    )
    choice = completion.choices[0]
    assert choice.logprobs is not None
    assert choice.logprobs.token_logprobs is not None
    assert choice.logprobs.top_logprobs is not None
    assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
                                            model_name: str):

    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        await client.completions.create(
            model=MODEL_NAME,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=21,
        )
        ...
    with pytest.raises(
        (openai.BadRequestError, openai.APIError)):  # test using token IDs
        stream = await client.completions.create(
            model=MODEL_NAME,
            prompt=[0, 0, 0, 0, 0],
            max_tokens=5,
            temperature=0.0,
            # vLLM has higher default max_logprobs (20 instead of 5) to support
            # both Completion API and Chat Completion API
            logprobs=30,
            stream=True,
        )
        async for chunk in stream:
            ...

    # the server should still work afterwards
    completion = await client.completions.create(
        model=model_name,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
                                    model_name: str):
    prompt = "What is an LLM?"

    single_completion = await client.completions.create(
        model=model_name,
        prompt=prompt,
        max_tokens=5,
        temperature=0.0,
    )
    single_output = single_completion.choices[0].text
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True)
    chunks: List[str] = []
    finish_reason_count = 0
    async for chunk in stream:
        chunks.append(chunk.choices[0].text)
        if chunk.choices[0].finish_reason is not None:
            finish_reason_count += 1
    # finish reason should only return in last block
    assert finish_reason_count == 1
    assert chunk.choices[0].finish_reason == "length"
    assert chunk.choices[0].text
    assert "".join(chunks) == single_output


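# stream_options matrix: every combination of include_usage and
# continuous_usage_stats with stream=True, plus the stream=False cases that
# must be rejected with BadRequestError.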
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    ["HuggingFaceH4/zephyr-7b-beta", "zephyr-lora"],
)
async def test_completion_stream_options(client: openai.AsyncOpenAI,
                                         model_name: str):
    prompt = "What is the capital of France?"

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": False}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": False,
                                                 "continuous_usage_stats":
                                                 False,
                                             })

    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": False, "continuous_usage_stats": True}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": False,
                                                 "continuous_usage_stats":
                                                 True,
                                             })
    async for chunk in stream:
        assert chunk.usage is None

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": False}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": True,
                                                 "continuous_usage_stats":
                                                 False,
                                             })
    async for chunk in stream:
        if chunk.choices[0].finish_reason is None:
            assert chunk.usage is None
        else:
            assert chunk.usage is None
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=True, stream_options=
    #     {"include_usage": True, "continuous_usage_stats": True}
    stream = await client.completions.create(model=model_name,
                                             prompt=prompt,
                                             max_tokens=5,
                                             temperature=0.0,
                                             stream=True,
                                             stream_options={
                                                 "include_usage": True,
                                                 "continuous_usage_stats":
                                                 True,
                                             })
    async for chunk in stream:
        assert chunk.usage is not None
        assert chunk.usage.prompt_tokens > 0
        assert chunk.usage.completion_tokens > 0
        assert chunk.usage.total_tokens == (chunk.usage.prompt_tokens +
                                            chunk.usage.completion_tokens)
        if chunk.choices[0].finish_reason is not None:
            final_chunk = await stream.__anext__()
            assert final_chunk.usage is not None
            assert final_chunk.usage.prompt_tokens > 0
            assert final_chunk.usage.completion_tokens > 0
            assert final_chunk.usage.total_tokens == (
                final_chunk.usage.prompt_tokens +
                final_chunk.usage.completion_tokens)
            assert final_chunk.choices == []

    # Test stream=False, stream_options=
    #     {"include_usage": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": None})

    # Test stream=False, stream_options=
    #     {"include_usage": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(model=model_name,
                                        prompt=prompt,
                                        max_tokens=5,
                                        temperature=0.0,
                                        stream=False,
                                        stream_options={"include_usage": True})

    # Test stream=False, stream_options=
    #     {"continuous_usage_stats": None}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"continuous_usage_stats": None})

    # Test stream=False, stream_options=
    #     {"continuous_usage_stats": True}
    with pytest.raises(BadRequestError):
        await client.completions.create(
            model=model_name,
            prompt=prompt,
            max_tokens=5,
            temperature=0.0,
            stream=False,
            stream_options={"continuous_usage_stats": True})


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # just test 1 lora hereafter
    "model_name",
    [MODEL_NAME, "zephyr-lora"],
)
async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
    # test both text and token IDs
    for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2):
        # test simple list
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
        )
        assert len(batch.choices) == 2
        assert batch.choices[0].text == batch.choices[1].text

        # test n = 2
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            n=2,
            max_tokens=5,
            temperature=0.0,
            extra_body=dict(
                # NOTE: this has to be true for n > 1 in vLLM, but not necessary
                # for official client.
                use_beam_search=True),
        )
        assert len(batch.choices) == 4
        assert batch.choices[0].text != batch.choices[
            1].text, "beam search should be different"
        assert batch.choices[0].text == batch.choices[
            2].text, "two copies of the same prompt should be the same"
        assert batch.choices[1].text == batch.choices[
            3].text, "two copies of the same prompt should be the same"

        # test streaming
        batch = await client.completions.create(
            model=model_name,
            prompt=prompts,
            max_tokens=5,
            temperature=0.0,
            stream=True,
        )
        texts = [""] * 2
        async for chunk in batch:
            assert len(chunk.choices) == 1
            choice = chunk.choices[0]
            texts[choice.index] += choice.text
        assert texts[0] == texts[1]


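# logit_bias takes token IDs as string keys; the test below relies on a +100
# bias effectively forcing the biased token and a -100 bias effectively
# banning it.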
@pytest.mark.asyncio
async def test_logits_bias(client: openai.AsyncOpenAI):
    prompt = "Hello, my name is"
    max_tokens = 5
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)

    # Test exclusive selection
    token_id = 1000
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token_id): 100},
        seed=42,
    )
    assert len(completion.choices[0].text) >= 5
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    expected_tokens = tokenizer(tokenizer.decode([token_id] * 5),
                                add_special_tokens=False)["input_ids"]
    assert all([
        response == expected
        for response, expected in zip(response_tokens, expected_tokens)
    ])

    # Test ban
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
    )
    response_tokens = tokenizer(completion.choices[0].text,
                                add_special_tokens=False)["input_ids"]
    first_response = completion.choices[0].text
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=prompt,
        max_tokens=max_tokens,
        temperature=0.0,
        logit_bias={str(token): -100
                    for token in response_tokens},
    )
    assert first_response != completion.choices[0].text


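# The guided decoding tests below constrain generation via extra_body with a
# JSON schema, a regex, a choice list, or a grammar, then validate the output
# against the same constraint.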
@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_json_completion(client: openai.AsyncOpenAI,
                                      guided_decoding_backend: str,
                                      sample_json_schema):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example JSON for an employee profile "
        f"that fits this schema: {sample_json_schema}",
        n=3,
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_json=sample_json_schema,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 3
    for i in range(3):
        output_json = json.loads(completion.choices[i].text)
        jsonschema.validate(instance=output_json, schema=sample_json_schema)


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_regex_completion(client: openai.AsyncOpenAI,
                                       guided_decoding_backend: str,
                                       sample_regex):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=f"Give an example IPv4 address with this regex: {sample_regex}",
        n=3,
        temperature=1.0,
        max_tokens=20,
        extra_body=dict(guided_regex=sample_regex,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 3
    for i in range(3):
        assert re.fullmatch(sample_regex,
                            completion.choices[i].text) is not None


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_choice_completion(client: openai.AsyncOpenAI,
                                        guided_decoding_backend: str,
                                        sample_guided_choice):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="The best language for type-safe systems programming is ",
        n=2,
        temperature=1.0,
        max_tokens=10,
        extra_body=dict(guided_choice=sample_guided_choice,
                        guided_decoding_backend=guided_decoding_backend))

    assert completion.id is not None
    assert len(completion.choices) == 2
    for i in range(2):
        assert completion.choices[i].text in sample_guided_choice


@pytest.mark.asyncio
async def test_guided_grammar(client: openai.AsyncOpenAI,
                              sample_sql_statements):

    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=("Generate a SQL statement that selects col_1 from "
                "table_1 where it is equal to 1"),
        temperature=1.0,
        max_tokens=500,
        extra_body=dict(guided_grammar=sample_sql_statements))

    content = completion.choices[0].text

    # use Lark to parse the output, and make sure it's a valid parse tree
    from lark import Lark
    parser = Lark(sample_sql_statements)
    parser.parse(content)

    # remove spaces for comparison b/c we removed them in the grammar
    ground_truth = "SELECT col_1 from table_1 where col_1 = 1".replace(" ", "")

    assert content.strip() == ground_truth


@pytest.mark.asyncio
@pytest.mark.parametrize(
    # first test base model, then test loras
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
)
@pytest.mark.parametrize("logprobs_arg", [1, 0])
async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
                                       model_name: str, logprobs_arg: int):
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
    # test using text and token IDs
    for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
        completion = await client.completions.create(model=model_name,
                                                      prompt=prompt,
                                                      max_tokens=5,
                                                      temperature=0.0,
                                                      echo=True,
                                                      logprobs=logprobs_arg)

        prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
                                                             list) else prompt
        assert re.search(r"^" + prompt_text, completion.choices[0].text)
        logprobs = completion.choices[0].logprobs
        assert logprobs is not None
        assert len(logprobs.text_offset) > 5
        assert (len(logprobs.token_logprobs) > 5
                and logprobs.token_logprobs[0] is None)
        assert (len(logprobs.top_logprobs) > 5
                and logprobs.top_logprobs[0] is None)
        for top_logprobs in logprobs.top_logprobs[1:]:
            assert max(logprobs_arg,
                       1) <= len(top_logprobs) <= logprobs_arg + 1
        assert len(logprobs.tokens) > 5


@pytest.mark.asyncio
@pytest.mark.parametrize("guided_decoding_backend",
                         ["outlines", "lm-format-enforcer"])
async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
                                          guided_decoding_backend: str,
                                          sample_json_schema, sample_regex):
    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example JSON that fits this schema: 42",
            extra_body=dict(guided_json=42,
                            guided_decoding_backend=guided_decoding_backend))

    with pytest.raises(openai.BadRequestError):
        _ = await client.completions.create(
            model=MODEL_NAME,
            prompt="Give an example string that fits this regex",
            extra_body=dict(guided_regex=sample_regex,
                            guided_json=sample_json_schema))


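# /tokenize and /detokenize are extra endpoints of the vLLM server (not part
# of the OpenAI client), so they are exercised with plain requests.post calls.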
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_tokenize(client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3].strip("/")
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

    for add_special in [False, True]:
        prompt = "This is a test prompt."
        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

        response = requests.post(base_url + "/tokenize",
                                 json={
                                     "add_special_tokens": add_special,
                                     "model": model_name,
                                     "prompt": prompt
                                 })
        response.raise_for_status()
        assert response.json() == {
            "tokens": tokens,
            "count": len(tokens),
            "max_model_len": 8192
        }


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_detokenize(client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3]
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

    prompt = "This is a test prompt."
    tokens = tokenizer.encode(prompt, add_special_tokens=False)

    response = requests.post(base_url + "detokenize",
                             json={
                                 "model": model_name,
                                 "tokens": tokens
                             })
    response.raise_for_status()
    assert response.json() == {"prompt": prompt}