# SPDX-License-Identifier: Apache-2.0

import asyncio
import copy
import functools
import os
import signal
import subprocess
import sys
import tempfile
import time
import warnings
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union

import cloudpickle
import openai
import pytest
import requests
import torch
import torch.nn.functional as F
from openai.types.completion import Completion
from typing_extensions import ParamSpec

import vllm.envs as envs
from tests.models.utils import TextTextLogprobs
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.platforms import current_platform
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.utils import (FlexibleArgumentParser, GB_bytes,
                        cuda_device_count_stateless, get_open_port)

if current_platform.is_rocm():
    from amdsmi import (amdsmi_get_gpu_vram_usage,
                        amdsmi_get_processor_handles, amdsmi_init,
                        amdsmi_shut_down)

    @contextmanager
    def _nvml():
        try:
            amdsmi_init()
            yield
        finally:
            amdsmi_shut_down()
elif current_platform.is_cuda():
    from vllm.third_party.pynvml import (nvmlDeviceGetHandleByIndex,
                                         nvmlDeviceGetMemoryInfo, nvmlInit,
                                         nvmlShutdown)

    @contextmanager
    def _nvml():
        try:
            nvmlInit()
            yield
        finally:
            nvmlShutdown()
else:

    @contextmanager
    def _nvml():
        yield
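
# `_nvml()` is built with `contextlib.contextmanager`, so the object it returns
# can wrap code either via a `with` block or as a decorator (the `@_nvml()`
# form used on `wait_for_gpu_memory_to_clear` below). A minimal sketch, with a
# hypothetical helper name:
#
#     @_nvml()
#     def check_gpu_state():          # hypothetical function
#         ...                         # NVML/amdsmi is initialized here
#
#     with _nvml():
#         ...                         # and shut down again on exit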


VLLM_PATH = Path(__file__).parent.parent
"""Path to root of the vLLM repository."""


class RemoteOpenAIServer:
    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need an API key

    def __init__(self,
                 model: str,
                 vllm_serve_args: list[str],
                 *,
                 env_dict: Optional[dict[str, str]] = None,
                 seed: Optional[int] = 0,
                 auto_port: bool = True,
                 max_wait_seconds: Optional[float] = None) -> None:
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port "
                                 "when `auto_port=True`.")

            # Don't mutate the input args
            vllm_serve_args = vllm_serve_args + [
                "--port", str(get_open_port())
            ]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError("You have manually specified the seed "
                                 f"when `seed={seed}`.")

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        parser = FlexibleArgumentParser(
            description="vLLM's remote OpenAI server.")
        parser = make_arg_parser(parser)
        args = parser.parse_args(["--model", model, *vllm_serve_args])
        self.host = str(args.host or 'localhost')
        self.port = int(args.port)

        self.show_hidden_metrics = \
            args.show_hidden_metrics_for_version is not None

        # Download the model before starting the server to avoid timing out.
        is_local = os.path.isdir(model)
        if not is_local:
            engine_args = AsyncEngineArgs.from_cli_args(args)
            model_config = engine_args.create_model_config()
            load_config = engine_args.create_load_config()

            model_loader = get_model_loader(load_config)
            model_loader.download_model(model_config)

        env = os.environ.copy()
        # The current process might have initialized CUDA already,
        # so to be safe we make the workers use the spawn method.
        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
        if env_dict is not None:
            env.update(env_dict)
        self.proc = subprocess.Popen(
            ["vllm", "serve", model, *vllm_serve_args],
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
        max_wait_seconds = max_wait_seconds or 240
        self._wait_for_server(url=self.url_for("health"),
                              timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.proc.terminate()
        try:
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()

    def _wait_for_server(self, *, url: str, timeout: float):
        # Poll the health endpoint until the server is ready.
        start = time.time()
        while True:
            try:
                if requests.get(url).status_code == 200:
                    break
            except Exception:
                # This exception can only be raised by requests.get,
                # which means the server is not ready yet.
                # The stack trace is not useful, so we suppress it
                # by using `raise ... from None`.
                result = self.proc.poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None

                time.sleep(0.5)
                if time.time() - start > timeout:
                    raise RuntimeError(
                        "Server failed to start in time.") from None

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.AsyncOpenAI(base_url=self.url_for("v1"),
                                  api_key=self.DUMMY_API_KEY,
                                  max_retries=0,
                                  **kwargs)
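
# Typical usage of `RemoteOpenAIServer`, as in `compare_all_settings` and
# `completions_with_server_args` below (a minimal sketch; the model name and
# extra CLI flags are placeholders, not recommendations):
#
#     with RemoteOpenAIServer("facebook/opt-125m",
#                             ["--max-model-len", "256"]) as server:
#         client = server.get_client()
#         out = client.completions.create(model="facebook/opt-125m",
#                                         prompt="Hello", max_tokens=5)
#
# The context manager launches `vllm serve`, waits for the /health endpoint,
# and terminates the subprocess on exit.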


def _test_completion(
    client: openai.OpenAI,
    model: str,
    prompt: str,
    token_ids: list[int],
):
    results = []

    # test with text prompt
    completion = client.completions.create(model=model,
                                           prompt=prompt,
                                           max_tokens=5,
                                           temperature=0.0)

    results.append({
        "test": "single_completion",
        "text": completion.choices[0].text,
        "finish_reason": completion.choices[0].finish_reason,
        "usage": completion.usage,
    })

    # test using token IDs
    completion = client.completions.create(
        model=model,
        prompt=token_ids,
        max_tokens=5,
        temperature=0.0,
    )

    results.append({
        "test": "token_ids",
        "text": completion.choices[0].text,
        "finish_reason": completion.choices[0].finish_reason,
        "usage": completion.usage,
    })

    # test seeded random sampling
    completion = client.completions.create(model=model,
                                           prompt=prompt,
                                           max_tokens=5,
                                           seed=33,
                                           temperature=1.0)

    results.append({
        "test": "seeded_sampling",
        "text": completion.choices[0].text,
        "finish_reason": completion.choices[0].finish_reason,
        "usage": completion.usage,
    })

    # test seeded random sampling with multiple prompts
    completion = client.completions.create(model=model,
                                           prompt=[prompt, prompt],
                                           max_tokens=5,
                                           seed=33,
                                           temperature=1.0)

    results.append({
        "test":
        "seeded_sampling",
        "text": [choice.text for choice in completion.choices],
        "finish_reason":
        [choice.finish_reason for choice in completion.choices],
        "usage":
        completion.usage,
    })

    # test simple list
    batch = client.completions.create(
        model=model,
        prompt=[prompt, prompt],
        max_tokens=5,
        temperature=0.0,
    )

    results.append({
        "test": "simple_list",
        "text0": batch.choices[0].text,
        "text1": batch.choices[1].text,
    })

    # test streaming
    batch = client.completions.create(
        model=model,
        prompt=[prompt, prompt],
        max_tokens=5,
        temperature=0.0,
        stream=True,
    )

    texts = [""] * 2
    for chunk in batch:
        assert len(chunk.choices) == 1
        choice = chunk.choices[0]
        texts[choice.index] += choice.text

    results.append({
        "test": "streaming",
        "texts": texts,
    })

    return results


def _test_completion_close(
    client: openai.OpenAI,
    model: str,
    prompt: str,
):
    results = []

    # test with text prompt
    completion = client.completions.create(model=model,
                                           prompt=prompt,
                                           max_tokens=1,
                                           logprobs=5,
                                           temperature=0.0)

    logprobs = completion.choices[0].logprobs.top_logprobs[0]
    logprobs = {k: round(v, 2) for k, v in logprobs.items()}

    results.append({
        "test": "completion_close",
        "logprobs": logprobs,
    })

    return results


def _test_chat(
    client: openai.OpenAI,
    model: str,
    prompt: str,
):
    results = []

    messages = [{
        "role": "user",
        "content": [{
            "type": "text",
            "text": prompt
        }]
    }]

    # test with text prompt
    chat_response = client.chat.completions.create(model=model,
                                                   messages=messages,
                                                   max_tokens=5,
                                                   temperature=0.0)

    results.append({
        "test": "completion_close",
        "text": chat_response.choices[0].message.content,
        "finish_reason": chat_response.choices[0].finish_reason,
        "usage": chat_response.usage,
    })

    return results


def _test_embeddings(
    client: openai.OpenAI,
    model: str,
    text: str,
):
    results = []

    # test with text input
    embeddings = client.embeddings.create(
        model=model,
        input=text,
        encoding_format="float",
    )

    results.append({
        "test": "single_embedding",
        "embedding": embeddings.data[0].embedding,
        "usage": embeddings.usage,
    })

    return results


def _test_image_text(
    client: openai.OpenAI,
    model_name: str,
    image_url: str,
):
    results = []

    # test pure text input
    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "How do you feel today?"
            },
        ],
    }]

    chat_completion = client.chat.completions.create(model=model_name,
                                                     messages=messages,
                                                     temperature=0.0,
                                                     max_tokens=1,
                                                     logprobs=True,
                                                     top_logprobs=5)
    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

    for x in top_logprobs:
        x.logprob = round(x.logprob, 2)

    results.append({
        "test": "pure_text",
        "logprobs": top_logprobs,
    })

    messages = [{
        "role":
        "user",
        "content": [
            {
                "type": "image_url",
                "image_url": {
                    "url": image_url
                }
            },
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    chat_completion = client.chat.completions.create(model=model_name,
                                                     messages=messages,
                                                     temperature=0.0,
                                                     max_tokens=1,
                                                     logprobs=True,
                                                     top_logprobs=5)
    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

    results.append({
        "test": "text_image",
        "logprobs": top_logprobs,
    })

    return results


def compare_two_settings(model: str,
                         arg1: list[str],
                         arg2: list[str],
                         env1: Optional[dict[str, str]] = None,
                         env2: Optional[dict[str, str]] = None,
                         *,
                         method: str = "generate",
                         max_wait_seconds: Optional[float] = None) -> None:
    """
    Launch the API server with two different sets of arguments/environments
    and compare the results of the API calls.

    Args:
        model: The model to test.
        arg1: The first set of arguments to pass to the API server.
        arg2: The second set of arguments to pass to the API server.
        env1: The first set of environment variables to pass to the API server.
        env2: The second set of environment variables to pass to the API server.
        method: Which family of API calls to exercise ("generate",
            "generate_close", "generate_chat", "generate_with_image", or
            "encode").
        max_wait_seconds: Timeout for each server to start up.
    """

    compare_all_settings(
        model,
        [arg1, arg2],
        [env1, env2],
        method=method,
        max_wait_seconds=max_wait_seconds,
    )
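
# A minimal sketch of how `compare_two_settings` is typically invoked
# (the model name and flags below are placeholders for illustration only):
#
#     compare_two_settings("facebook/opt-125m",
#                          arg1=["--enforce-eager"],
#                          arg2=[],
#                          env1=None,
#                          env2=None)
#
# The two servers are launched one after the other and their API outputs are
# asserted to be identical.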


def compare_all_settings(model: str,
                         all_args: list[list[str]],
                         all_envs: list[Optional[dict[str, str]]],
                         *,
                         method: str = "generate",
                         max_wait_seconds: Optional[float] = None) -> None:
    """
    Launch the API server with several different sets of arguments/environments
    and compare the results of the API calls against the first set of
    arguments.

    Args:
        model: The model to test.
        all_args: A list of argument lists to pass to the API server.
        all_envs: A list of environment dictionaries to pass to the API server.
        method: Which family of API calls to exercise ("generate",
            "generate_close", "generate_chat", "generate_with_image", or
            "encode").
        max_wait_seconds: Timeout for each server to start up.
    """

    trust_remote_code = False
    for args in all_args:
        if "--trust-remote-code" in args:
            trust_remote_code = True
            break

    tokenizer_mode = "auto"
    for args in all_args:
        if "--tokenizer-mode" in args:
            tokenizer_mode = args[args.index("--tokenizer-mode") + 1]
            break

    tokenizer = get_tokenizer(
        model,
        trust_remote_code=trust_remote_code,
        tokenizer_mode=tokenizer_mode,
    )

    can_force_load_format = True

    for args in all_args:
        if "--load-format" in args:
            can_force_load_format = False
            break

    prompt = "Hello, my name is"
    token_ids = tokenizer(prompt).input_ids
    ref_results: list = []
    for i, (args, env) in enumerate(zip(all_args, all_envs)):
        if can_force_load_format:
            # Since we only compare outputs, real weights are usually not
            # needed, so we default to dummy weights by forcing the load
            # format. This works for most cases; when it does not (e.g. in
            # quantization tests), the VLLM_TEST_FORCE_LOAD_FORMAT
            # environment variable can be used to pick a different load
            # format.
            args = args + ["--load-format", envs.VLLM_TEST_FORCE_LOAD_FORMAT]
        compare_results: list = []
        results = ref_results if i == 0 else compare_results
        with RemoteOpenAIServer(model,
                                args,
                                env_dict=env,
                                max_wait_seconds=max_wait_seconds) as server:
            client = server.get_client()

            # test models list
            models = client.models.list()
            models = models.data
            served_model = models[0]
            results.append({
                "test": "models_list",
                "id": served_model.id,
                "root": served_model.root,
            })

            if method == "generate":
                results += _test_completion(client, model, prompt, token_ids)
            elif method == "generate_close":
                results += _test_completion_close(client, model, prompt)
            elif method == "generate_chat":
                results += _test_chat(client, model, prompt)
            elif method == "generate_with_image":
                results += _test_image_text(
                    client, model,
                    "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png"
                )
            elif method == "encode":
                results += _test_embeddings(client, model, prompt)
            else:
                raise ValueError(f"Unknown method: {method}")

            if i > 0:
                # if any setting fails, raise an error early
                ref_args = all_args[0]
                ref_envs = all_envs[0]
                compare_args = all_args[i]
                compare_envs = all_envs[i]
                for ref_result, compare_result in zip(ref_results,
                                                      compare_results):
                    ref_result = copy.deepcopy(ref_result)
                    compare_result = copy.deepcopy(compare_result)
                    if "embedding" in ref_result and method == "encode":
                        sim = F.cosine_similarity(
                            torch.tensor(ref_result["embedding"]),
                            torch.tensor(compare_result["embedding"]),
                            dim=0,
                        )
                        assert sim >= 0.999, (
                            f"Embedding for {model=} are not the same.\n"
                            f"cosine_similarity={sim}\n")
                        del ref_result["embedding"]
                        del compare_result["embedding"]
                    assert ref_result == compare_result, (
                        f"Results for {model=} are not the same.\n"
                        f"{ref_args=} {ref_envs=}\n"
                        f"{compare_args=} {compare_envs=}\n"
                        f"{ref_result=}\n"
                        f"{compare_result=}\n")


def init_test_distributed_environment(
    tp_size: int,
    pp_size: int,
    rank: int,
    distributed_init_port: str,
    local_rank: int = -1,
) -> None:
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    init_distributed_environment(
        world_size=pp_size * tp_size,
        rank=rank,
        distributed_init_method=distributed_init_method,
        local_rank=local_rank)
    ensure_model_parallel_initialized(tp_size, pp_size)


def multi_process_parallel(
    monkeypatch: pytest.MonkeyPatch,
    tp_size: int,
    pp_size: int,
    test_target: Any,
) -> None:
    import ray

    # Using Ray instead of multiprocessing makes failures easier to debug.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers.
    # NOTE: Force Ray not to use the gitignore file for exclusions; otherwise
    # it will not copy .so files to the working dir, so we manually exclude
    # some of the large directories instead.
    os.environ["RAY_RUNTIME_ENV_IGNORE_GITIGNORE"] = "1"
    ray.init(
        runtime_env={
            "working_dir": VLLM_PATH,
            "excludes":
            ["build", ".git", "cmake-build-*", "shellcheck", "dist"]
        })

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(tp_size * pp_size):
        refs.append(
            test_target.remote(
                monkeypatch,
                tp_size,
                pp_size,
                rank,
                distributed_init_port,
            ), )
    ray.get(refs)

    ray.shutdown()


@contextmanager
def error_on_warning(category: type[Warning] = Warning):
    """
    Within the scope of this context manager, tests will fail if any warning
    of the given category is emitted.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("error", category=category)

        yield
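
# A minimal usage sketch (hypothetical test body, for illustration only):
#
#     with error_on_warning(DeprecationWarning):
#         call_some_api()   # placeholder; the test fails if this emits a
#                           # DeprecationWarning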


def get_physical_device_indices(devices):
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        return devices

    visible_indices = [int(x) for x in visible_devices.split(",")]
    index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
    return [index_mapping[i] for i in devices if i in index_mapping]
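
# For example, with CUDA_VISIBLE_DEVICES="2,3", logical devices [0, 1] map to
# physical devices [2, 3]; if CUDA_VISIBLE_DEVICES is unset, the input list is
# returned unchanged.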


@_nvml()
def wait_for_gpu_memory_to_clear(devices: list[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    devices = get_physical_device_indices(devices)
    start_time = time.time()
    while True:
        output: dict[int, str] = {}
        output_raw: dict[int, float] = {}
        for device in devices:
            if current_platform.is_rocm():
                dev_handle = amdsmi_get_processor_handles()[device]
                mem_info = amdsmi_get_gpu_vram_usage(dev_handle)
                gb_used = mem_info["vram_used"] / 2**10
            else:
                dev_handle = nvmlDeviceGetHandleByIndex(device)
                mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
                gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)


_P = ParamSpec("_P")


def fork_new_process_for_each_test(
        f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to fork a new process for each test function.
    See https://github.com/vllm-project/vllm/issues/7053 for more details.
    """

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Make the process the leader of its own process group
        # to avoid sending SIGTERM to the parent process
        os.setpgrp()
        from _pytest.outcomes import Skipped
        pid = os.fork()
        print(f"Fork a new process to run a test {pid}")
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Skipped as e:
                # convert Skipped to exit code 0
                print(str(e))
                os._exit(0)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            pgid = os.getpgid(pid)
            _pid, _exitcode = os.waitpid(pid, 0)
            # ignore SIGTERM signal itself
            old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
            # kill all child processes
            os.killpg(pgid, signal.SIGTERM)
            # restore the signal handler
            signal.signal(signal.SIGTERM, old_signal_handler)
            assert _exitcode == 0, (f"function {f} failed when called with"
                                    f" args {args} and kwargs {kwargs}")

    return wrapper


def spawn_new_process_for_each_test(
        f: Callable[_P, None]) -> Callable[_P, None]:
    """Decorator to spawn a new process for each test function.
    """

    @functools.wraps(f)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
        # Check if we're already in a subprocess
        if os.environ.get('RUNNING_IN_SUBPROCESS') == '1':
            # If we are, just run the function directly
            return f(*args, **kwargs)

        import torch.multiprocessing as mp
        with suppress(RuntimeError):
            mp.set_start_method('spawn')

        # Get the module
        module_name = f.__module__

        # Create a process with environment variable set
        env = os.environ.copy()
        env['RUNNING_IN_SUBPROCESS'] = '1'

        with tempfile.TemporaryDirectory() as tempdir:
            output_filepath = os.path.join(tempdir, "new_process.tmp")

            # `cloudpickle` allows pickling complex functions directly
            input_bytes = cloudpickle.dumps((f, output_filepath))

            cmd = [sys.executable, "-m", f"{module_name}"]

            returned = subprocess.run(cmd,
                                      input=input_bytes,
                                      capture_output=True,
                                      env=env)

            # check if the subprocess is successful
            try:
                returned.check_returncode()
            except Exception as e:
                # wrap raised exception to provide more information
                raise RuntimeError(f"Error raised in subprocess:\n"
                                   f"{returned.stderr.decode()}") from e

    return wrapper


def create_new_process_for_each_test(
    method: Optional[Literal["spawn", "fork"]] = None
) -> Callable[[Callable[_P, None]], Callable[_P, None]]:
    """Creates a decorator that runs each test function in a new process.

    Args:
        method: The process creation method. Can be either "spawn" or "fork".
            If not specified, it defaults to "spawn" on ROCm platforms and
            "fork" otherwise.

    Returns:
        A decorator to run test functions in separate processes.
    """
    if method is None:
        method = "spawn" if current_platform.is_rocm() else "fork"

    assert method in ["spawn",
                      "fork"], "Method must be either 'spawn' or 'fork'"

    if method == "fork":
        return fork_new_process_for_each_test

    return spawn_new_process_for_each_test
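
# A minimal usage sketch (hypothetical test, for illustration only):
#
#     @create_new_process_for_each_test()
#     def test_something_gpu_heavy():
#         ...   # runs in a freshly forked/spawned process, so it cannot leak
#               # CUDA state into other tests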


def large_gpu_mark(min_gb: int) -> pytest.MarkDecorator:
    """
    Get a pytest mark, which skips the test if the GPU doesn't meet
    a minimum memory requirement in GB.

    This can be leveraged via `@large_gpu_test` to skip tests in environments
    without enough resources, or called when filtering tests to run directly.
    """
    try:
        if current_platform.is_cpu():
            memory_gb = 0
        else:
            memory_gb = current_platform.get_device_total_memory() / GB_bytes
    except Exception as e:
        warnings.warn(
            f"An error occurred when finding the available memory: {e}",
            stacklevel=2,
        )
        memory_gb = 0

    return pytest.mark.skipif(
        memory_gb < min_gb,
        reason=f"Need at least {min_gb}GB GPU memory to run the test.",
    )


def large_gpu_test(*, min_gb: int):
    """
    Decorate a test to be skipped if no GPU is available or it does not have
    sufficient memory.

    Currently, the CI machine uses L4 GPU which has 24 GB VRAM.
    """
    mark = large_gpu_mark(min_gb)

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        return mark(f)

    return wrapper
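
# A minimal usage sketch (the memory threshold is just an example value):
#
#     @large_gpu_test(min_gb=48)
#     def test_very_large_model():
#         ...   # skipped automatically on GPUs with less than 48 GB VRAM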


def multi_gpu_marks(*, num_gpus: int):
    """Get a collection of pytest marks to apply for `@multi_gpu_test`."""
    test_selector = pytest.mark.distributed(num_gpus=num_gpus)
    test_skipif = pytest.mark.skipif(
        cuda_device_count_stateless() < num_gpus,
        reason=f"Need at least {num_gpus} GPUs to run the test.",
    )

    return [test_selector, test_skipif]


def multi_gpu_test(*, num_gpus: int):
    """
    Decorate a test to be run only when multiple GPUs are available.
    """
    marks = multi_gpu_marks(num_gpus=num_gpus)

    def wrapper(f: Callable[_P, None]) -> Callable[_P, None]:
        func = create_new_process_for_each_test()(f)
        for mark in reversed(marks):
            func = mark(func)

        return func

    return wrapper
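
# A minimal usage sketch (hypothetical test, for illustration only):
#
#     @multi_gpu_test(num_gpus=2)
#     def test_tensor_parallel():
#         ...   # skipped unless at least 2 GPUs are visible; the body runs
#               # in its own process via create_new_process_for_each_test()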


async def completions_with_server_args(
    prompts: list[str],
    model_name: str,
    server_cli_args: list[str],
    num_logprobs: Optional[int],
    max_wait_seconds: int = 240,
    max_tokens: Union[int, list] = 5,
) -> list[Completion]:
    '''Construct a remote OpenAI server, obtain an async client to the
    server & invoke the completions API to obtain completions.

    Args:
        prompts: test prompts
        model_name: model to spin up on the vLLM server
        server_cli_args: CLI args for starting the server
        num_logprobs: Number of logprobs to report (or `None`)
        max_wait_seconds: timeout interval for bringing up server.
            Default: 240sec
        max_tokens: max_tokens value for each of the given input prompts.
            If only one max_tokens value is given, the same value is used
            for all the prompts.

    Returns:
        List of OpenAI Completion instances, one per prompt.
    '''

    if isinstance(max_tokens, int):
        max_tokens = [max_tokens] * len(prompts)

    assert len(max_tokens) == len(prompts)

    outputs = None
    with RemoteOpenAIServer(model_name,
                            server_cli_args,
                            max_wait_seconds=max_wait_seconds) as server:
        client = server.get_async_client()
        outputs = [
            client.completions.create(model=model_name,
                                      prompt=[p],
                                      temperature=0,
                                      stream=False,
                                      max_tokens=max_tok,
                                      logprobs=num_logprobs)
            for p, max_tok in zip(prompts, max_tokens)
        ]
        outputs = await asyncio.gather(*outputs)

    assert outputs is not None, "Completion API call failed."

    return outputs
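
# A minimal usage sketch from an async test (model and args are placeholders):
#
#     completions = await completions_with_server_args(
#         prompts=["Hello, my name is"],
#         model_name="facebook/opt-125m",
#         server_cli_args=[],
#         num_logprobs=5)
#     texts = get_client_text_generations(completions)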


def get_client_text_generations(completions: list[Completion]) -> list[str]:
    '''Extract the generated text from the output of a
    request made to an Open-AI-protocol completions endpoint.
    '''
    assert all([len(x.choices) == 1 for x in completions])
    return [x.choices[0].text for x in completions]


def get_client_text_logprob_generations(
        completions: list[Completion]) -> list[TextTextLogprobs]:
    '''Operates on the output of a request made to an Open-AI-protocol
    completions endpoint; obtains top-rank logprobs for each token in
    each :class:`SequenceGroup`
    '''
    text_generations = get_client_text_generations(completions)
    text = ''.join(text_generations)
    return [(text_generations, text,
             (None if x.logprobs is None else x.logprobs.top_logprobs))
            for completion in completions for x in completion.choices]