import asyncio
import os
import socket
from typing import AsyncIterator, Tuple

import pytest
import torch
from vllm_test_utils import monitor

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
                        get_open_port, memory_profiling, merge_async_iterators,
                        supports_kw)

from .utils import error_on_warning, fork_new_process_for_each_test


@pytest.mark.asyncio
async def test_merge_async_iterators():

    async def mock_async_iterator(idx: int):
        try:
            while True:
                yield f"item from iterator {idx}"
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            print(f"iterator {idx} cancelled")

    iterators = [mock_async_iterator(i) for i in range(3)]
    merged_iterator = merge_async_iterators(*iterators)
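    # merge_async_iterators interleaves the source iterators into a single
    # stream of (source index, item) tuples, which stream_output below
    # unpacks.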

    async def stream_output(generator: AsyncIterator[Tuple[int, str]]):
        async for idx, output in generator:
            print(f"idx: {idx}, output: {output}")

    task = asyncio.create_task(stream_output(merged_iterator))
    await asyncio.sleep(0.5)
    task.cancel()
    with pytest.raises(asyncio.CancelledError):
        await task

    for iterator in iterators:
        try:
            # Can use anext() in python >= 3.10
            await asyncio.wait_for(iterator.__anext__(), 1)
        except StopAsyncIteration:
            # All iterators should be cancelled and print this message.
            print("Iterator was cancelled normally")
        except (Exception, asyncio.CancelledError) as e:
            raise AssertionError() from e


def test_deprecate_kwargs_always():

    @deprecate_kwargs("old_arg", is_deprecated=True)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_never():

    @deprecate_kwargs("old_arg", is_deprecated=False)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_dynamic():
    is_deprecated = True

    @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated)
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="'old_arg'"):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)

    is_deprecated = False
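
    # The lambda passed to @deprecate_kwargs re-reads `is_deprecated` from the
    # enclosing scope on every call, so flipping the flag above is enough to
    # silence the warning without re-decorating `dummy`.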
    with error_on_warning(DeprecationWarning):
        dummy(old_arg=1)

    with error_on_warning(DeprecationWarning):
        dummy(new_arg=1)


def test_deprecate_kwargs_additional_message():

    @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd")
    def dummy(*, old_arg: object = None, new_arg: object = None):
        pass

    with pytest.warns(DeprecationWarning, match="abcd"):
        dummy(old_arg=1)


def test_get_open_port():
    os.environ["VLLM_PORT"] = "5678"
    # Make sure we can get multiple ports, even if the env var is set.
    # The sockets are nested so that earlier ports stay occupied, forcing
    # each get_open_port() call to return a fresh port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
        s1.bind(("localhost", get_open_port()))
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
            s2.bind(("localhost", get_open_port()))
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
                s3.bind(("localhost", get_open_port()))
    os.environ.pop("VLLM_PORT")


# Tests for FlexibleArgumentParser
@pytest.fixture
def parser():
    parser = FlexibleArgumentParser()
    parser.add_argument('--image-input-type',
                        choices=['pixel_values', 'image_features'])
    parser.add_argument('--model-name')
    parser.add_argument('--batch-size', type=int)
    parser.add_argument('--enable-feature', action='store_true')
    return parser


@pytest.fixture
def parser_with_config():
    parser = FlexibleArgumentParser()
    parser.add_argument('serve')
    parser.add_argument('model_tag')
    parser.add_argument('--served-model-name', type=str)
    parser.add_argument('--config', type=str)
    parser.add_argument('--port', type=int)
    parser.add_argument('--tensor-parallel-size', type=int)
    parser.add_argument('--trust-remote-code', action='store_true')
    parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean)
    return parser
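

# FlexibleArgumentParser accepts both --dash-style and --underscore_style
# spellings of each flag and can expand a --config YAML file into CLI
# arguments; the tests below exercise both behaviors.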


def test_underscore_to_dash(parser):
    args = parser.parse_args(['--image_input_type', 'pixel_values'])
    assert args.image_input_type == 'pixel_values'


def test_mixed_usage(parser):
    args = parser.parse_args([
        '--image_input_type', 'image_features', '--model-name',
        'facebook/opt-125m'
    ])
    assert args.image_input_type == 'image_features'
    assert args.model_name == 'facebook/opt-125m'


def test_with_equals_sign(parser):
    args = parser.parse_args(
        ['--image_input_type=pixel_values', '--model-name=facebook/opt-125m'])
    assert args.image_input_type == 'pixel_values'
    assert args.model_name == 'facebook/opt-125m'


def test_with_int_value(parser):
    args = parser.parse_args(['--batch_size', '32'])
    assert args.batch_size == 32
    args = parser.parse_args(['--batch-size', '32'])
    assert args.batch_size == 32


def test_with_bool_flag(parser):
    args = parser.parse_args(['--enable_feature'])
    assert args.enable_feature is True
    args = parser.parse_args(['--enable-feature'])
    assert args.enable_feature is True


def test_invalid_choice(parser):
    with pytest.raises(SystemExit):
        parser.parse_args(['--image_input_type', 'invalid_choice'])


def test_missing_required_argument(parser):
    parser.add_argument('--required-arg', required=True)
    with pytest.raises(SystemExit):
        parser.parse_args([])


def test_cli_override_to_config(parser_with_config):
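    # An explicit CLI value should override the value loaded from --config,
    # regardless of whether it appears before or after the --config flag.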
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--config', './data/test_config.yaml',
        '--tensor-parallel-size', '3'
    ])
    assert args.tensor_parallel_size == 3
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
        './data/test_config.yaml'
    ])
    assert args.tensor_parallel_size == 3
    assert args.port == 12312
    args = parser_with_config.parse_args([
        'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
        './data/test_config.yaml', '--port', '666'
    ])
    assert args.tensor_parallel_size == 3
    assert args.port == 666


def test_config_args(parser_with_config):
    args = parser_with_config.parse_args(
        ['serve', 'mymodel', '--config', './data/test_config.yaml'])
    assert args.tensor_parallel_size == 2
    assert args.trust_remote_code
    assert not args.multi_step_stream_outputs


def test_config_file(parser_with_config):
    # The config file must exist.
    with pytest.raises(FileNotFoundError):
        parser_with_config.parse_args(
            ['serve', 'mymodel', '--config', 'test_config.yml'])

    # Only YAML config files are accepted.
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            ['serve', 'mymodel', '--config', './data/test_config.json'])

    # --config must be followed by a file path, not another flag.
    with pytest.raises(ValueError):
        parser_with_config.parse_args([
            'serve', 'mymodel', '--tensor-parallel-size', '3', '--config',
            '--batch-size', '32'
        ])


def test_no_model_tag(parser_with_config):
    with pytest.raises(ValueError):
        parser_with_config.parse_args(
            ['serve', '--config', './data/test_config.yaml'])


# yapf: disable
@pytest.mark.parametrize(
    "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported",
    [
        # Tests for positional argument support
        (lambda foo: None, "foo", True, True, False),
        (lambda foo: None, "foo", False, True, True),
        # Tests for positional or keyword / keyword only
        (lambda foo=100: None, "foo", True, True, False),
        (lambda *, foo: None, "foo", False, True, True),
        # Tests to make sure the names of variadic params are NOT supported
        (lambda *args: None, "args", False, True, False),
        (lambda **kwargs: None, "kwargs", False, True, False),
        # Tests for whether allow_var_kwargs adds support
        (lambda foo: None, "something_else", False, True, False),
        (lambda foo, **kwargs: None, "something_else", False, True, True),
        (lambda foo, **kwargs: None, "kwargs", True, True, False),
        (lambda foo, **kwargs: None, "foo", True, True, False),
    ])
# yapf: enable
def test_supports_kw(callable, kw_name, requires_kw_only, allow_var_kwargs,
                     is_supported):
    assert supports_kw(
        callable=callable,
        kw_name=kw_name,
        requires_kw_only=requires_kw_only,
        allow_var_kwargs=allow_var_kwargs
    ) == is_supported


@fork_new_process_for_each_test
def test_memory_profiling():
    # Fake out some model loading + inference memory usage to test profiling.
    # Memory used by other processes will show up as cuda usage outside of
    # torch.
    from vllm.distributed.device_communicators.cuda_wrapper import (
        CudaRTLibrary)
    lib = CudaRTLibrary()
    # 512 MiB allocation outside of this instance
    handle1 = lib.cudaMalloc(512 * 1024 * 1024)

    # mem_get_info() returns (free, total), so total - free is the memory
    # already in use on the device.
    baseline_memory_in_bytes = \
        torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]

    # Load weights: 128M float32 elements.
    weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32)

    weights_memory_in_bytes = 128 * 1024 * 1024 * 4  # 512 MiB
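
    # Note: the 512 MiB cudaMalloc above happened before the baseline was
    # measured, so it is part of baseline_memory_in_bytes and should not be
    # attributed to this test's profiling result.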

    def measure_current_non_torch():
        free, total = torch.cuda.mem_get_info()
        current_used = total - free
        current_torch = torch.cuda.memory_reserved()
        current_non_torch = current_used - current_torch
        return current_non_torch
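
    # "Non-torch" memory is everything the device reports as in use minus
    # what PyTorch's caching allocator has reserved, e.g. the raw cudaMalloc
    # allocations in this test (or NCCL buffers in a real run).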

    with memory_profiling(baseline_memory_in_bytes=baseline_memory_in_bytes,
                          weights_memory_in_bytes=weights_memory_in_bytes) as result, \
            monitor(measure_current_non_torch) as monitored_values:
        # make a memory spike, 1 GiB
        spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32)
        del spike

        # Add some extra non-torch memory, 256 MiB (simulate NCCL)
        handle2 = lib.cudaMalloc(256 * 1024 * 1024)

        # This check is exact: the only non-torch change inside the context
        # is the 256 MiB cudaMalloc above, so the first and last monitored
        # values must differ by exactly that amount.
        measured_diff = monitored_values.values[-1] - monitored_values.values[0]
        assert measured_diff == 256 * 1024 * 1024

    # Check that the memory usage is within 5% of the expected values.
    # The 5% tolerance is due to the PyTorch caching allocator: we cannot
    # control its internal buffering behavior, which introduces a small
    # error (<10 MiB in practice).
    non_torch_ratio = result.non_torch_increase_in_bytes / (256 * 1024 * 1024)  # noqa
    torch_peak_ratio = result.torch_peak_increase_in_bytes / (1024 * 1024 * 1024)  # noqa
    assert abs(non_torch_ratio - 1) <= 0.05
    assert abs(torch_peak_ratio - 1) <= 0.05
    del weights
    lib.cudaFree(handle1)
    lib.cudaFree(handle2)