import asyncio
import time
from itertools import cycle
from typing import Dict, List, Optional, Tuple, Union

import pytest
import ray
import torch

from vllm.utils import is_hip

if not is_hip():
    from pynvml import (nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo,
                        nvmlInit)

from vllm import LLM
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest
from vllm.model_executor.utils import set_random_seed
from vllm.multimodal import MultiModalData
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, random_uuid

from ...conftest import cleanup


class AsyncLLM:
    """AsyncLLM

    Note: The current LLM class in vLLM does not support async mode. For test
    purposes we implement an async variant here; it may move to
    vllm/entrypoints/llm.py in the future.

    AsyncLLM below is borrowed directly from vllm/entrypoints/llm.py, with
    changes to make it work in async mode.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = AsyncEngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            # For now use ray for the distributed back-end, since
            # we rely on the use of engine_use_ray=True to avoid
            # reinitializing CUDA in the same process (driver worker).
            engine_use_ray=True,
            distributed_executor_backend="ray",
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.request_counter = Counter()
        self.llm_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:

        if prompts is None:
            raise ValueError("prompts must be provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]

        if prompts is not None:
            num_requests = len(prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        elif isinstance(sampling_params,
                        list) and len(sampling_params) != num_requests:
            raise ValueError("The lengths of prompts and "
                             "sampling_params must be the same.")

        async def get_output(prompt, sampling_param) -> RequestOutput:
            request_id = random_uuid()
            results_generator = self.llm_engine.generate(
                prompt, sampling_param, request_id)
            final_output = None
            async for request_output in results_generator:
                final_output = request_output
            assert final_output is not None
            return final_output

        outputs: List[RequestOutput] = []
        try:
            for i in range(num_requests):
                prompt = prompts[i] if prompts is not None else None
                res = asyncio.run(get_output(prompt, sampling_params))
                outputs.append(res)
        finally:
            ray.shutdown()
        return outputs

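
# Illustrative usage of AsyncLLM (a sketch only, not executed by pytest; the
# model name and sampling values below are example assumptions):
#
#   llm = AsyncLLM(model="facebook/opt-125m", tensor_parallel_size=1)
#   outputs = llm.generate(["Hello, my name is"],
#                          SamplingParams(temperature=0.0, max_tokens=8))
#   print(outputs[0].outputs[0].text)
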
@pytest.fixture
def baseline_llm_generator(request, common_llm_kwargs,
                           per_test_common_llm_kwargs, baseline_llm_kwargs,
                           seed):
    return create_llm_generator("baseline", request, common_llm_kwargs,
                                per_test_common_llm_kwargs,
                                baseline_llm_kwargs, seed)


@pytest.fixture
def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
                       test_llm_kwargs, seed):
    return create_llm_generator("test", request, common_llm_kwargs,
                                per_test_common_llm_kwargs, test_llm_kwargs,
                                seed)


def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
                         per_test_common_llm_kwargs, distinct_llm_kwargs,
                         seed):
    kwargs = {
        **common_llm_kwargs,
        **per_test_common_llm_kwargs,
        **distinct_llm_kwargs,
    }
    test_name = request.node.name

    def generator_inner():

        wait_for_gpu_memory_to_clear(
            devices=list(range(torch.cuda.device_count())),
            threshold_bytes=2 * 2**30,
            timeout_s=60,
        )

        use_async = False
        if "use_async" in kwargs:
            use_async = kwargs.pop("use_async")
        print(f'{use_async=}')

        print(f'Creating {baseline_or_test=} LLM for {test_name=}. {kwargs=}')
        llm = AsyncLLM(**kwargs) if use_async else LLM(**kwargs)
        set_random_seed(seed)

        yield llm
        del llm
        cleanup()

    def generator_outer():
        for llm in generator_inner():
            yield llm
            del llm

    return generator_outer

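
# The fixtures above return generator *factories*: calling the returned
# function yields exactly one LLM (or AsyncLLM) instance and cleans it up once
# the caller is done with it. A hypothetical consumer looks like:
#
#   for llm in test_llm_generator():
#       outputs = llm.generate(prompts, sampling_params)
#       del llm
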
def maybe_assert_ngram_worker(llm):
    # Verify the proposer worker is ngram if ngram is specified.
    if (not isinstance(llm, AsyncLLM)
            and llm.llm_engine.speculative_config is not None
            and llm.llm_engine.speculative_config.ngram_prompt_lookup_max > 0):
        from vllm.spec_decode.ngram_worker import NGramWorker
        assert isinstance(
            llm.llm_engine.model_executor.driver_worker.proposer_worker,
            NGramWorker)

def get_output_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> Tuple[List[str], List[List[int]]]:
    tokens: List[str] = []
    token_ids: List[List[int]] = []
    for llm in llm_generator():
        maybe_assert_ngram_worker(llm)

        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        token_ids = [output.outputs[0].token_ids for output in outputs]
        tokens = [output.outputs[0].text for output in outputs]
        del llm

    return tokens, token_ids

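
# Illustrative call (a sketch; the prompt and sampling values are assumptions):
#
#   tokens, token_ids = get_output_from_llm_generator(
#       test_llm_generator, ["Hello, my name is"],
#       SamplingParams(temperature=0.0, max_tokens=8))
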
def get_logprobs_from_llm_generator(
        llm_generator, prompts,
        sampling_params) -> List[List[Dict[int, Logprob]]]:
    """Returns a dict of (token_id: Logprob) for each generated position, for
    each sequence in the batch.
    """
    for llm in llm_generator():
        outputs = llm.generate(prompts, sampling_params, use_tqdm=True)
        logprobs = [output.outputs[0].logprobs[:] for output in outputs]
        del llm

    return logprobs

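
# The return value is indexed [sequence][position] and maps token_id to a
# Logprob. A hypothetical inspection of the first generated position of the
# first sequence:
#
#   batch_logprobs = get_logprobs_from_llm_generator(
#       test_llm_generator, prompts, sampling_params)
#   for token_id, logprob in batch_logprobs[0][0].items():
#       print(token_id, logprob.logprob)
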
def run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len,
                                         force_output_len: bool,
                                         print_tokens: bool = False):
    """Helper method that compares the outputs of both the baseline LLM and
    the test LLM. It asserts greedy equality, i.e. that the outputs are exactly
    the same when temperature is zero.
    """
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
        "San Francisco is known for its",
        "Facebook was created in 2004 by",
        "Curious George is a",
        "Python 3.11 brings improvements to its",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    # If the test requires that we generate max_output_len tokens, then set the
    # sampling params to ignore the eos token.
    ignore_eos = force_output_len

    sampling_params = SamplingParams(
        max_tokens=max_output_len,
        ignore_eos=ignore_eos,
        temperature=temperature,
    )

    spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
        test_llm_generator, prompts, sampling_params)

    (baseline_batch_tokens,
     baseline_batch_token_ids) = get_output_from_llm_generator(
         baseline_llm_generator, prompts, sampling_params)

    assert len(baseline_batch_token_ids) == len(prompts)
    assert len(spec_batch_token_ids) == len(prompts)

    for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
            spec_tokens) in enumerate(
                zip(baseline_batch_token_ids, baseline_batch_tokens,
                    spec_batch_token_ids, spec_batch_tokens)):
        if print_tokens:
            print(f'{i=} {baseline_tokens=}')
            print(f'{i=} {spec_tokens=}')
        print(f'{i=} {baseline_token_ids=}')
        print(f'{i=} {spec_token_ids=}')
        assert baseline_token_ids == spec_token_ids

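
# A typical caller (sketch only; the parametrization and values below are
# illustrative assumptions) wires the fixtures above into this helper:
#
#   @pytest.mark.parametrize("common_llm_kwargs",
#                            [{"model": "JackFram/llama-68m"}])
#   @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
#   @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
#   @pytest.mark.parametrize("test_llm_kwargs", [{}])
#   @pytest.mark.parametrize("seed", [1])
#   def test_greedy_equality(baseline_llm_generator, test_llm_generator):
#       run_greedy_equality_correctness_test(baseline_llm_generator,
#                                            test_llm_generator,
#                                            batch_size=4,
#                                            max_output_len=32,
#                                            force_output_len=True)
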
def wait_for_gpu_memory_to_clear(devices: List[int],
                                 threshold_bytes: int,
                                 timeout_s: float = 120) -> None:
    # Use nvml instead of pytorch to reduce measurement error from torch cuda
    # context.
    nvmlInit()
    start_time = time.time()
    while True:
        output: Dict[int, str] = {}
        output_raw: Dict[int, float] = {}
        for device in devices:
            dev_handle = nvmlDeviceGetHandleByIndex(device)
            mem_info = nvmlDeviceGetMemoryInfo(dev_handle)
            gb_used = mem_info.used / 2**30
            output_raw[device] = gb_used
            output[device] = f'{gb_used:.02f}'

        print('gpu memory used (GB): ', end='')
        for k, v in output.items():
            print(f'{k}={v}; ', end='')
        print('')

        dur_s = time.time() - start_time
        if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()):
            print(f'Done waiting for free GPU memory on devices {devices=} '
                  f'({threshold_bytes/2**30=}) {dur_s=:.02f}')
            break

        if dur_s >= timeout_s:
            raise ValueError(f'Memory of devices {devices=} not free after '
                             f'{dur_s=:.02f} ({threshold_bytes/2**30=})')

        time.sleep(5)
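
# Illustrative call (values are example assumptions): block until GPUs 0 and 1
# each report at most 2 GiB in use, waiting up to two minutes:
#
#   wait_for_gpu_memory_to_clear(devices=[0, 1],
#                                threshold_bytes=2 * 2**30,
#                                timeout_s=120)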