diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py
index 34a26e9e..1c90cedf 100644
--- a/tests/lora/test_utils.py
+++ b/tests/lora/test_utils.py
@@ -9,7 +9,6 @@ from torch import nn
 
 from vllm.lora.utils import (get_adapter_absolute_path,
                              parse_fine_tuned_lora_name, replace_submodule)
-from vllm.utils import LRUCache
 
 
 def test_parse_fine_tuned_lora_name_valid():
@@ -85,114 +84,6 @@ def test_replace_submodule():
     assert dict(model.named_modules())["seq1.dense2"] == dense2
 
 
-class TestLRUCache(LRUCache):
-
-    def _on_remove(self, key, value):
-        if not hasattr(self, "_remove_counter"):
-            self._remove_counter = 0
-        self._remove_counter += 1
-
-
-def test_lru_cache():
-    cache = TestLRUCache(3)
-
-    cache.put(1, 1)
-    assert len(cache) == 1
-
-    cache.put(1, 1)
-    assert len(cache) == 1
-
-    cache.put(2, 2)
-    assert len(cache) == 2
-
-    cache.put(3, 3)
-    assert len(cache) == 3
-    assert set(cache.cache) == {1, 2, 3}
-
-    cache.put(4, 4)
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 3, 4}
-    assert cache._remove_counter == 1
-    assert cache.get(2) == 2
-
-    cache.put(5, 5)
-    assert set(cache.cache) == {2, 4, 5}
-    assert cache._remove_counter == 2
-
-    assert cache.pop(5) == 5
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache.pop(10)
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache.get(10)
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache.put(6, 6)
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 4, 6}
-    assert 2 in cache
-    assert 4 in cache
-    assert 6 in cache
-
-    cache.remove_oldest()
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 6}
-    assert cache._remove_counter == 4
-
-    cache.clear()
-    assert len(cache) == 0
-    assert cache._remove_counter == 6
-
-    cache._remove_counter = 0
-
-    cache[1] = 1
-    assert len(cache) == 1
-
-    cache[1] = 1
-    assert len(cache) == 1
-
-    cache[2] = 2
-    assert len(cache) == 2
-
-    cache[3] = 3
-    assert len(cache) == 3
-    assert set(cache.cache) == {1, 2, 3}
-
-    cache[4] = 4
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 3, 4}
-    assert cache._remove_counter == 1
-    assert cache[2] == 2
-
-    cache[5] = 5
-    assert set(cache.cache) == {2, 4, 5}
-    assert cache._remove_counter == 2
-
-    del cache[5]
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache.pop(10)
-    assert len(cache) == 2
-    assert set(cache.cache) == {2, 4}
-    assert cache._remove_counter == 3
-
-    cache[6] = 6
-    assert len(cache) == 3
-    assert set(cache.cache) == {2, 4, 6}
-    assert 2 in cache
-    assert 4 in cache
-    assert 6 in cache
-
-
 # Unit tests for get_adapter_absolute_path
 @patch('os.path.isabs')
 def test_get_adapter_absolute_path_absolute(mock_isabs):
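
Note: the LRUCache tests move out of the LoRA suite into tests/test_utils.py, next to the utility they cover. For reference, a minimal sketch of the eviction-hook pattern the moved test relies on — it assumes, as the test does, that vllm.utils.LRUCache invokes `_on_remove(key, value)` whenever an entry leaves the cache (eviction, pop, del, or clear); `CountingLRUCache` is a hypothetical name:

```python
from vllm.utils import LRUCache


class CountingLRUCache(LRUCache):
    """Hypothetical subclass that counts removals, like TestLRUCache."""

    def _on_remove(self, key, value):
        if not hasattr(self, "_remove_counter"):
            self._remove_counter = 0
        self._remove_counter += 1


cache = CountingLRUCache(2)
cache.put("a", 1)
cache.put("b", 2)
cache.put("c", 3)  # over capacity: evicts "a" and fires _on_remove once
assert cache._remove_counter == 1
```
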
diff --git a/tests/test_utils.py b/tests/test_utils.py
index b6129a10..580e65f1 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -13,11 +13,11 @@ import torch
 from vllm_test_utils.monitor import monitor
 
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
-from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
-                        PlaceholderModule, StoreBoolean, bind_kv_cache,
-                        deprecate_kwargs, get_open_port, memory_profiling,
-                        merge_async_iterators, sha256, supports_kw,
-                        swap_dict_values)
+from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
+                        MemorySnapshot, PlaceholderModule, StoreBoolean,
+                        bind_kv_cache, deprecate_kwargs, get_open_port,
+                        memory_profiling, merge_async_iterators, sha256,
+                        supports_kw, swap_dict_values)
 
 from .utils import create_new_process_for_each_test, error_on_warning
 
@@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
     assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
 
 
+class TestLRUCache(LRUCache):
+
+    def _on_remove(self, key, value):
+        if not hasattr(self, "_remove_counter"):
+            self._remove_counter = 0
+        self._remove_counter += 1
+
+
+def test_lru_cache():
+    cache = TestLRUCache(3)
+    assert cache.stat() == CacheInfo(hits=0, total=0)
+    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
+
+    cache.put(1, 1)
+    assert len(cache) == 1
+
+    cache.put(1, 1)
+    assert len(cache) == 1
+
+    cache.put(2, 2)
+    assert len(cache) == 2
+
+    cache.put(3, 3)
+    assert len(cache) == 3
+    assert set(cache.cache) == {1, 2, 3}
+
+    cache.put(4, 4)
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 3, 4}
+    assert cache._remove_counter == 1
+
+    assert cache.get(2) == 2
+    assert cache.stat() == CacheInfo(hits=1, total=1)
+    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
+
+    assert cache[2] == 2
+    assert cache.stat() == CacheInfo(hits=2, total=2)
+    assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
+
+    cache.put(5, 5)
+    assert set(cache.cache) == {2, 4, 5}
+    assert cache._remove_counter == 2
+
+    assert cache.pop(5) == 5
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    assert cache.get(-1) is None
+    assert cache.stat() == CacheInfo(hits=2, total=3)
+    assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
+
+    cache.pop(10)
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache.get(10)
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache.put(6, 6)
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 4, 6}
+    assert 2 in cache
+    assert 4 in cache
+    assert 6 in cache
+
+    cache.remove_oldest()
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 6}
+    assert cache._remove_counter == 4
+
+    cache.clear()
+    assert len(cache) == 0
+    assert cache._remove_counter == 6
+    assert cache.stat() == CacheInfo(hits=0, total=0)
+    assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
+
+    cache._remove_counter = 0
+
+    cache[1] = 1
+    assert len(cache) == 1
+
+    cache[1] = 1
+    assert len(cache) == 1
+
+    cache[2] = 2
+    assert len(cache) == 2
+
+    cache[3] = 3
+    assert len(cache) == 3
+    assert set(cache.cache) == {1, 2, 3}
+
+    cache[4] = 4
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 3, 4}
+    assert cache._remove_counter == 1
+    assert cache[2] == 2
+
+    cache[5] = 5
+    assert set(cache.cache) == {2, 4, 5}
+    assert cache._remove_counter == 2
+
+    del cache[5]
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache.pop(10)
+    assert len(cache) == 2
+    assert set(cache.cache) == {2, 4}
+    assert cache._remove_counter == 3
+
+    cache[6] = 6
+    assert len(cache) == 3
+    assert set(cache.cache) == {2, 4, 6}
+    assert 2 in cache
+    assert 4 in cache
+    assert 6 in cache
+
+
 def test_placeholder_module_error_handling():
     placeholder = PlaceholderModule("placeholder_1234")
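
The new assertions above pin down the `stat()` semantics: `stat()` is cumulative over the cache's lifetime, while `stat(delta=True)` reports only the hits and queries since the previous `delta=True` call. A condensed sketch of that behavior, assuming the LRUCache/CacheInfo API introduced in this diff:

```python
from vllm.utils import CacheInfo, LRUCache

cache = LRUCache(3)
cache.put(1, 1)

cache.get(1)  # hit
cache.get(2)  # miss
assert cache.stat() == CacheInfo(hits=1, total=2)            # cumulative
assert cache.stat(delta=True) == CacheInfo(hits=1, total=2)  # first delta call

cache.get(1)  # another hit
assert cache.stat() == CacheInfo(hits=2, total=3)            # still cumulative
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)  # only the new query
```

Note that plain `stat()` does not reset the delta baseline; only `delta=True` calls advance `_last_info`, which is what lets periodic logging report per-interval hit rates.
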
diff --git a/vllm/utils.py b/vllm/utils.py
index c2aad049..c4c74a37 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -236,6 +236,12 @@ class CacheInfo(NamedTuple):
 
         return self.hits / self.total
 
+    def __sub__(self, other: CacheInfo):
+        return CacheInfo(
+            hits=self.hits - other.hits,
+            total=self.total - other.total,
+        )
+
 
 class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
 
@@ -243,15 +249,26 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
                  capacity: float,
                  getsizeof: Optional[Callable[[_V], float]] = None):
         super().__init__(capacity, getsizeof)
+
         self.pinned_items = set[_K]()
-        self.capacity = capacity
 
         self._hits = 0
         self._total = 0
+        self._last_info = CacheInfo(hits=0, total=0)
+
+    def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
+        value = super().__getitem__(key)
+
+        if update_info:
+            self._hits += 1
+            self._total += 1
+
+        return value
 
     def __delitem__(self, key: _K) -> None:
         run_on_remove = key in self
-        value = self.__getitem__(key)
+        value = self.__getitem__(key,
+                                 update_info=False)  # type: ignore[call-arg]
         super().__delitem__(key)
         if key in self.pinned_items:
             # Todo: add warning to inform that del pinned item
@@ -271,8 +288,32 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
         """Return the internal order dictionary (read-only)."""
         return MappingProxyType(self._LRUCache__order)  # type: ignore
 
-    def stat(self) -> CacheInfo:
-        return CacheInfo(hits=self._hits, total=self._total)
+    @property
+    def capacity(self) -> float:
+        return self.maxsize
+
+    @property
+    def usage(self) -> float:
+        if self.maxsize == 0:
+            return 0
+
+        return self.currsize / self.maxsize
+
+    def stat(self, *, delta: bool = False) -> CacheInfo:
+        """
+        Gets the cumulative number of hits and queries against this cache.
+
+        If :code:`delta=True`, instead gets these statistics
+        since the last call that also passed :code:`delta=True`.
+        """
+        info = CacheInfo(hits=self._hits, total=self._total)
+
+        if delta:
+            info_delta = info - self._last_info
+            self._last_info = info
+            info = info_delta
+
+        return info
 
     def touch(self, key: _K) -> None:
         self._LRUCache__update(key)  # type: ignore
@@ -292,7 +333,8 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
                                      _T]] = None) -> Optional[Union[_V, _T]]:
         value: Optional[Union[_V, _T]]
         if key in self:
-            value = self.__getitem__(key)
+            value = self.__getitem__(
+                key, update_info=False)  # type: ignore[call-arg]
 
             self._hits += 1
         else:
@@ -317,8 +359,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
         if key not in self:
             return default
 
-        value = self[key]
-        del self[key]
+        value = self.__getitem__(key,
+                                 update_info=False)  # type: ignore[call-arg]
+        self.__delitem__(key)
         return value
 
     def put(self, key: _K, value: _V) -> None:
@@ -353,10 +396,6 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
         while self.currsize > self.capacity:
             self.remove_oldest()
 
-    def clear(self) -> None:
-        while len(self) > 0:
-            self.remove_oldest(remove_pinned=True)
-
     def popitem(self, remove_pinned: bool = False):
         """Remove and return the `(key, value)` pair least recently used."""
         if not remove_pinned:
@@ -372,6 +411,14 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
             value = self.pop(cast(_K, lru_key))
         return (lru_key, value)
 
+    def clear(self) -> None:
+        while len(self) > 0:
+            self.remove_oldest(remove_pinned=True)
+
+        self._hits = 0
+        self._total = 0
+        self._last_info = CacheInfo(hits=0, total=0)
+
 
 class PyObjectCache:
     """Used to cache python objects to avoid object allocations
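
The keyword-only `update_info` flag is the core design choice here: internal readers (`__delitem__`, `get`, `pop`) fetch values with `update_info=False` so housekeeping reads never inflate the hit statistics, while plain subscript access counts as a query. A small sketch of that distinction, plus the new `capacity`/`usage` properties, under the same API assumptions as above:

```python
from vllm.utils import CacheInfo, LRUCache

cache = LRUCache(2)
cache.put("k", 42)        # put() does not touch the stats

_ = cache["k"]            # subscript access: counted as a hit
assert cache.stat() == CacheInfo(hits=1, total=1)

_ = cache.pop("k")        # pop() reads via update_info=False: stats unchanged
assert cache.stat() == CacheInfo(hits=1, total=1)

assert cache.capacity == 2  # property backed by cachetools' maxsize
assert cache.usage == 0.0   # currsize / maxsize, now that the cache is empty
```

Replacing the old `self.capacity` attribute with a property over `maxsize` also removes the duplicated state that previously had to be kept in sync with the cachetools base class.
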
diff --git a/vllm/v1/engine/mm_input_cache.py b/vllm/v1/engine/mm_input_cache.py
index ef5a2e5a..c765c1bb 100644
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -50,7 +50,7 @@ class MirroredProcessingCache:
 
         full_mm_inputs = list[Optional[MultiModalKwargs]]()
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
-            if mm_hash in self.mm_cache:
+            if self.mm_cache.get(mm_hash) is not None:
                 mm_input = None
             else:
                 self.mm_cache[mm_hash] = mm_input
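
This swap from `in` to `get()` is what makes multimodal cache hits visible in the new statistics: a plain membership test bypasses the hit/total accounting (and, in cachetools-based caches, leaves the LRU order untouched), whereas `get()` records a hit and refreshes recency. One caveat worth noting: `get(...) is not None` also treats a stored `None` value as a miss, which should be harmless here since cached multimodal inputs are presumably never `None`. A sketch, assuming the stat-tracking LRUCache from vllm/utils.py above:

```python
from vllm.utils import LRUCache

cache = LRUCache(4)
cache.put("mm_hash_1", object())  # stand-in for a MultiModalKwargs entry

"mm_hash_1" in cache    # no hit recorded; LRU order unchanged
cache.get("mm_hash_1")  # counted as a hit and marks the entry recently used

info = cache.stat()
assert info.hits == 1 and info.total == 1
```
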