# SPDX-License-Identifier: Apache-2.0 # ruff: noqa import asyncio import hashlib import pickle import socket from collections.abc import AsyncIterator from unittest.mock import patch import pytest import torch from vllm_test_utils.monitor import monitor from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, deprecate_kwargs, get_open_port, memory_profiling, merge_async_iterators, sha256, supports_kw, swap_dict_values) from .utils import create_new_process_for_each_test, error_on_warning @pytest.mark.asyncio async def test_merge_async_iterators(): async def mock_async_iterator(idx: int): try: while True: yield f"item from iterator {idx}" await asyncio.sleep(0.1) except asyncio.CancelledError: print(f"iterator {idx} cancelled") iterators = [mock_async_iterator(i) for i in range(3)] merged_iterator = merge_async_iterators(*iterators) async def stream_output(generator: AsyncIterator[tuple[int, str]]): async for idx, output in generator: print(f"idx: {idx}, output: {output}") task = asyncio.create_task(stream_output(merged_iterator)) await asyncio.sleep(0.5) task.cancel() with pytest.raises(asyncio.CancelledError): await task for iterator in iterators: try: # Can use anext() in python >= 3.10 await asyncio.wait_for(iterator.__anext__(), 1) except StopAsyncIteration: # All iterators should be cancelled and print this message. print("Iterator was cancelled normally") except (Exception, asyncio.CancelledError) as e: raise AssertionError() from e def test_deprecate_kwargs_always(): @deprecate_kwargs("old_arg", is_deprecated=True) def dummy(*, old_arg: object = None, new_arg: object = None): pass with pytest.warns(DeprecationWarning, match="'old_arg'"): dummy(old_arg=1) with error_on_warning(DeprecationWarning): dummy(new_arg=1) def test_deprecate_kwargs_never(): @deprecate_kwargs("old_arg", is_deprecated=False) def dummy(*, old_arg: object = None, new_arg: object = None): pass with error_on_warning(DeprecationWarning): dummy(old_arg=1) with error_on_warning(DeprecationWarning): dummy(new_arg=1) def test_deprecate_kwargs_dynamic(): is_deprecated = True @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) def dummy(*, old_arg: object = None, new_arg: object = None): pass with pytest.warns(DeprecationWarning, match="'old_arg'"): dummy(old_arg=1) with error_on_warning(DeprecationWarning): dummy(new_arg=1) is_deprecated = False with error_on_warning(DeprecationWarning): dummy(old_arg=1) with error_on_warning(DeprecationWarning): dummy(new_arg=1) def test_deprecate_kwargs_additional_message(): @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") def dummy(*, old_arg: object = None, new_arg: object = None): pass with pytest.warns(DeprecationWarning, match="abcd"): dummy(old_arg=1) def test_get_open_port(monkeypatch: pytest.MonkeyPatch): with monkeypatch.context() as m: m.setenv("VLLM_PORT", "5678") # make sure we can get multiple ports, even if the env var is set with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1: s1.bind(("localhost", get_open_port())) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2: s2.bind(("localhost", get_open_port())) with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3: s3.bind(("localhost", get_open_port())) # Tests for FlexibleArgumentParser @pytest.fixture def parser(): parser = FlexibleArgumentParser() parser.add_argument('--image-input-type', choices=['pixel_values', 'image_features']) parser.add_argument('--model-name') parser.add_argument('--batch-size', type=int) parser.add_argument('--enable-feature', action='store_true') return parser @pytest.fixture def parser_with_config(): parser = FlexibleArgumentParser() parser.add_argument('serve') parser.add_argument('model_tag', nargs='?') parser.add_argument('--model', type=str) parser.add_argument('--served-model-name', type=str) parser.add_argument('--config', type=str) parser.add_argument('--port', type=int) parser.add_argument('--tensor-parallel-size', type=int) parser.add_argument('--trust-remote-code', action='store_true') parser.add_argument('--multi-step-stream-outputs', action=StoreBoolean) return parser def test_underscore_to_dash(parser): args = parser.parse_args(['--image_input_type', 'pixel_values']) assert args.image_input_type == 'pixel_values' def test_mixed_usage(parser): args = parser.parse_args([ '--image_input_type', 'image_features', '--model-name', 'facebook/opt-125m' ]) assert args.image_input_type == 'image_features' assert args.model_name == 'facebook/opt-125m' def test_with_equals_sign(parser): args = parser.parse_args( ['--image_input_type=pixel_values', '--model-name=facebook/opt-125m']) assert args.image_input_type == 'pixel_values' assert args.model_name == 'facebook/opt-125m' def test_with_int_value(parser): args = parser.parse_args(['--batch_size', '32']) assert args.batch_size == 32 args = parser.parse_args(['--batch-size', '32']) assert args.batch_size == 32 def test_with_bool_flag(parser): args = parser.parse_args(['--enable_feature']) assert args.enable_feature is True args = parser.parse_args(['--enable-feature']) assert args.enable_feature is True def test_invalid_choice(parser): with pytest.raises(SystemExit): parser.parse_args(['--image_input_type', 'invalid_choice']) def test_missing_required_argument(parser): parser.add_argument('--required-arg', required=True) with pytest.raises(SystemExit): parser.parse_args([]) def test_cli_override_to_config(parser_with_config, cli_config_file): args = parser_with_config.parse_args([ 'serve', 'mymodel', '--config', cli_config_file, '--tensor-parallel-size', '3' ]) assert args.tensor_parallel_size == 3 args = parser_with_config.parse_args([ 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', cli_config_file ]) assert args.tensor_parallel_size == 3 assert args.port == 12312 args = parser_with_config.parse_args([ 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', cli_config_file, '--port', '666' ]) assert args.tensor_parallel_size == 3 assert args.port == 666 def test_config_args(parser_with_config, cli_config_file): args = parser_with_config.parse_args( ['serve', 'mymodel', '--config', cli_config_file]) assert args.tensor_parallel_size == 2 assert args.trust_remote_code assert not args.multi_step_stream_outputs def test_config_file(parser_with_config): with pytest.raises(FileNotFoundError): parser_with_config.parse_args( ['serve', 'mymodel', '--config', 'test_config.yml']) with pytest.raises(ValueError): parser_with_config.parse_args( ['serve', 'mymodel', '--config', './data/test_config.json']) with pytest.raises(ValueError): parser_with_config.parse_args([ 'serve', 'mymodel', '--tensor-parallel-size', '3', '--config', '--batch-size', '32' ]) def test_no_model_tag(parser_with_config, cli_config_file): with pytest.raises(ValueError): parser_with_config.parse_args(['serve', '--config', cli_config_file]) # yapf: enable @pytest.mark.parametrize( "callable,kw_name,requires_kw_only,allow_var_kwargs,is_supported", [ # Tests for positional argument support (lambda foo: None, "foo", True, True, False), (lambda foo: None, "foo", False, True, True), # Tests for positional or keyword / keyword only (lambda foo=100: None, "foo", True, True, False), (lambda *, foo: None, "foo", False, True, True), # Tests to make sure the names of variadic params are NOT supported (lambda *args: None, "args", False, True, False), (lambda **kwargs: None, "kwargs", False, True, False), # Tests for if we allow var kwargs to add support (lambda foo: None, "something_else", False, True, False), (lambda foo, **kwargs: None, "something_else", False, True, True), (lambda foo, **kwargs: None, "kwargs", True, True, False), (lambda foo, **kwargs: None, "foo", True, True, False), ]) # yapf: disable def test_supports_kw(callable,kw_name,requires_kw_only, allow_var_kwargs,is_supported): assert supports_kw( callable=callable, kw_name=kw_name, requires_kw_only=requires_kw_only, allow_var_kwargs=allow_var_kwargs ) == is_supported @create_new_process_for_each_test() def test_memory_profiling(): # Fake out some model loading + inference memory usage to test profiling # Memory used by other processes will show up as cuda usage outside of torch from vllm.distributed.device_communicators.cuda_wrapper import ( CudaRTLibrary) lib = CudaRTLibrary() # 512 MiB allocation outside of this instance handle1 = lib.cudaMalloc(512 * 1024 * 1024) baseline_snapshot = MemorySnapshot() # load weights weights = torch.randn(128, 1024, 1024, device='cuda', dtype=torch.float32) weights_memory = 128 * 1024 * 1024 * 4 # 512 MiB def measure_current_non_torch(): free, total = torch.cuda.mem_get_info() current_used = total - free current_torch = torch.cuda.memory_reserved() current_non_torch = current_used - current_torch return current_non_torch with memory_profiling(baseline_snapshot=baseline_snapshot, weights_memory=weights_memory) as result, \ monitor(measure_current_non_torch) as monitored_values: # make a memory spike, 1 GiB spike = torch.randn(256, 1024, 1024, device='cuda', dtype=torch.float32) del spike # Add some extra non-torch memory 256 MiB (simulate NCCL) handle2 = lib.cudaMalloc(256 * 1024 * 1024) # this is an analytic value, it is exact, # we only have 256 MiB non-torch memory increase measured_diff = monitored_values.values[-1] - monitored_values.values[0] assert measured_diff == 256 * 1024 * 1024 # Check that the memory usage is within 5% of the expected values # 5% tolerance is caused by cuda runtime. # we cannot control cuda runtime in the granularity of bytes, # which causes a small error (<10 MiB in practice) non_torch_ratio = result.non_torch_increase / (256 * 1024 * 1024) # noqa assert abs(non_torch_ratio - 1) <= 0.05 assert result.torch_peak_increase == 1024 * 1024 * 1024 del weights lib.cudaFree(handle1) lib.cudaFree(handle2) def test_bind_kv_cache(): from vllm.attention import Attention ctx = { 'layers.0.self_attn': Attention(32, 128, 0.1), 'layers.1.self_attn': Attention(32, 128, 0.1), 'layers.2.self_attn': Attention(32, 128, 0.1), 'layers.3.self_attn': Attention(32, 128, 0.1), } kv_cache = [ torch.zeros((1, )), torch.zeros((1, )), torch.zeros((1, )), torch.zeros((1, )), ] bind_kv_cache(ctx, [kv_cache]) assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0] assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[1] assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[2] assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[3] def test_bind_kv_cache_non_attention(): from vllm.attention import Attention # example from Jamba PP=2 ctx = { 'model.layers.20.attn': Attention(32, 128, 0.1), 'model.layers.28.attn': Attention(32, 128, 0.1), } kv_cache = [ torch.zeros((1, )), torch.zeros((1, )), ] bind_kv_cache(ctx, [kv_cache]) assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[0] assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[1] def test_bind_kv_cache_encoder_decoder(monkeypatch: pytest.MonkeyPatch): # V1 TESTS: ENCODER_DECODER is not supported on V1 yet. with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "0") from vllm.attention import Attention, AttentionType # example from bart ctx = { 'encoder.layers.0.self_attn.attn': Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER), 'decoder.layers.0.encoder_attn.attn': Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER), 'decoder.layers.0.self_attn.attn': Attention(32, 128, 0.1, attn_type=AttentionType.DECODER), } kv_cache = [ torch.zeros((1, )), ] encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache bind_kv_cache(ctx, [kv_cache]) assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[0] assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[0] def test_bind_kv_cache_pp(): with patch("vllm.utils.cuda_device_count_stateless", lambda: 2): # this test runs with 1 GPU, but we simulate 2 GPUs cfg = VllmConfig( parallel_config=ParallelConfig(pipeline_parallel_size=2)) with set_current_vllm_config(cfg): from vllm.attention import Attention ctx = { 'layers.0.self_attn': Attention(32, 128, 0.1), } kv_cache = [ [torch.zeros((1, ))], [torch.zeros((1, ))] ] bind_kv_cache(ctx, kv_cache) assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[0][0] assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0] class TestLRUCache(LRUCache): def _on_remove(self, key, value): if not hasattr(self, "_remove_counter"): self._remove_counter = 0 self._remove_counter += 1 def test_lru_cache(): cache = TestLRUCache(3) assert cache.stat() == CacheInfo(hits=0, total=0) assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) cache.put(1, 1) assert len(cache) == 1 cache.put(1, 1) assert len(cache) == 1 cache.put(2, 2) assert len(cache) == 2 cache.put(3, 3) assert len(cache) == 3 assert set(cache.cache) == {1, 2, 3} cache.put(4, 4) assert len(cache) == 3 assert set(cache.cache) == {2, 3, 4} assert cache._remove_counter == 1 assert cache.get(2) == 2 assert cache.stat() == CacheInfo(hits=1, total=1) assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) assert cache[2] == 2 assert cache.stat() == CacheInfo(hits=2, total=2) assert cache.stat(delta=True) == CacheInfo(hits=1, total=1) cache.put(5, 5) assert set(cache.cache) == {2, 4, 5} assert cache._remove_counter == 2 assert cache.pop(5) == 5 assert len(cache) == 2 assert set(cache.cache) == {2, 4} assert cache._remove_counter == 3 assert cache.get(-1) is None assert cache.stat() == CacheInfo(hits=2, total=3) assert cache.stat(delta=True) == CacheInfo(hits=0, total=1) cache.pop(10) assert len(cache) == 2 assert set(cache.cache) == {2, 4} assert cache._remove_counter == 3 cache.get(10) assert len(cache) == 2 assert set(cache.cache) == {2, 4} assert cache._remove_counter == 3 cache.put(6, 6) assert len(cache) == 3 assert set(cache.cache) == {2, 4, 6} assert 2 in cache assert 4 in cache assert 6 in cache cache.remove_oldest() assert len(cache) == 2 assert set(cache.cache) == {2, 6} assert cache._remove_counter == 4 cache.clear() assert len(cache) == 0 assert cache._remove_counter == 6 assert cache.stat() == CacheInfo(hits=0, total=0) assert cache.stat(delta=True) == CacheInfo(hits=0, total=0) cache._remove_counter = 0 cache[1] = 1 assert len(cache) == 1 cache[1] = 1 assert len(cache) == 1 cache[2] = 2 assert len(cache) == 2 cache[3] = 3 assert len(cache) == 3 assert set(cache.cache) == {1, 2, 3} cache[4] = 4 assert len(cache) == 3 assert set(cache.cache) == {2, 3, 4} assert cache._remove_counter == 1 assert cache[2] == 2 cache[5] = 5 assert set(cache.cache) == {2, 4, 5} assert cache._remove_counter == 2 del cache[5] assert len(cache) == 2 assert set(cache.cache) == {2, 4} assert cache._remove_counter == 3 cache.pop(10) assert len(cache) == 2 assert set(cache.cache) == {2, 4} assert cache._remove_counter == 3 cache[6] = 6 assert len(cache) == 3 assert set(cache.cache) == {2, 4, 6} assert 2 in cache assert 4 in cache assert 6 in cache def test_placeholder_module_error_handling(): placeholder = PlaceholderModule("placeholder_1234") def build_ctx(): return pytest.raises(ModuleNotFoundError, match="No module named") with build_ctx(): int(placeholder) with build_ctx(): placeholder() with build_ctx(): _ = placeholder.some_attr with build_ctx(): # Test conflict with internal __name attribute _ = placeholder.name # OK to print the placeholder or use it in a f-string _ = repr(placeholder) _ = str(placeholder) # No error yet; only error when it is used downstream placeholder_attr = placeholder.placeholder_attr("attr") with build_ctx(): int(placeholder_attr) with build_ctx(): placeholder_attr() with build_ctx(): _ = placeholder_attr.some_attr with build_ctx(): # Test conflict with internal __module attribute _ = placeholder_attr.module @pytest.mark.parametrize( "obj,key1,key2", [ # Tests for both keys exist ({1: "a", 2: "b"}, 1, 2), # Tests for one key does not exist ({1: "a", 2: "b"}, 1, 3), # Tests for both keys do not exist ({1: "a", 2: "b"}, 3, 4), ]) def test_swap_dict_values(obj, key1, key2): original_obj = obj.copy() swap_dict_values(obj, key1, key2) if key1 in original_obj: assert obj[key2] == original_obj[key1] else: assert key2 not in obj if key2 in original_obj: assert obj[key1] == original_obj[key2] else: assert key1 not in obj def test_model_specification(parser_with_config, cli_config_file, cli_config_file_with_model): # Test model in CLI takes precedence over config args = parser_with_config.parse_args([ 'serve', 'cli-model', '--config', cli_config_file_with_model ]) assert args.model_tag == 'cli-model' assert args.served_model_name == 'mymodel' # Test model from config file works args = parser_with_config.parse_args([ 'serve', '--config', cli_config_file_with_model, ]) assert args.model == 'config-model' assert args.served_model_name == 'mymodel' # Test no model specified anywhere raises error with pytest.raises(ValueError, match="No model specified!"): parser_with_config.parse_args(['serve', '--config', cli_config_file]) # Test using --model option raises error with pytest.raises( ValueError, match=( "With `vllm serve`, you should provide the model as a positional " "argument or in a config file instead of via the `--model` option." ), ): parser_with_config.parse_args(['serve', '--model', 'my-model']) # Test other config values are preserved args = parser_with_config.parse_args([ 'serve', 'cli-model', '--config', cli_config_file_with_model, ]) assert args.tensor_parallel_size == 2 assert args.trust_remote_code is True assert args.multi_step_stream_outputs is False assert args.port == 12312 @pytest.mark.parametrize("input", [(), ("abc", ), (None, ), (None, bool, [1, 2, 3])]) @pytest.mark.parametrize("output", [0, 1, 2]) def test_sha256(input: tuple, output: int): hash = sha256(input) assert hash is not None assert isinstance(hash, int) assert hash != 0 bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL) assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big") # hashing again, returns the same value assert hash == sha256(input) # hashing different input, returns different value assert hash != sha256(input + (1, ))