[Bugfix] Multi-modal caches not acting like LRU caches (#16593)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
6bf27affb6
commit
aa29841ede
@ -9,7 +9,6 @@ from torch import nn
|
|||||||
|
|
||||||
from vllm.lora.utils import (get_adapter_absolute_path,
|
from vllm.lora.utils import (get_adapter_absolute_path,
|
||||||
parse_fine_tuned_lora_name, replace_submodule)
|
parse_fine_tuned_lora_name, replace_submodule)
|
||||||
from vllm.utils import LRUCache
|
|
||||||
|
|
||||||
|
|
||||||
def test_parse_fine_tuned_lora_name_valid():
|
def test_parse_fine_tuned_lora_name_valid():
|
||||||
@ -85,114 +84,6 @@ def test_replace_submodule():
|
|||||||
assert dict(model.named_modules())["seq1.dense2"] == dense2
|
assert dict(model.named_modules())["seq1.dense2"] == dense2
|
||||||
|
|
||||||
|
|
||||||
class TestLRUCache(LRUCache):
|
|
||||||
|
|
||||||
def _on_remove(self, key, value):
|
|
||||||
if not hasattr(self, "_remove_counter"):
|
|
||||||
self._remove_counter = 0
|
|
||||||
self._remove_counter += 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_lru_cache():
|
|
||||||
cache = TestLRUCache(3)
|
|
||||||
|
|
||||||
cache.put(1, 1)
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache.put(1, 1)
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache.put(2, 2)
|
|
||||||
assert len(cache) == 2
|
|
||||||
|
|
||||||
cache.put(3, 3)
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {1, 2, 3}
|
|
||||||
|
|
||||||
cache.put(4, 4)
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 3, 4}
|
|
||||||
assert cache._remove_counter == 1
|
|
||||||
assert cache.get(2) == 2
|
|
||||||
|
|
||||||
cache.put(5, 5)
|
|
||||||
assert set(cache.cache) == {2, 4, 5}
|
|
||||||
assert cache._remove_counter == 2
|
|
||||||
|
|
||||||
assert cache.pop(5) == 5
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache.pop(10)
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache.get(10)
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache.put(6, 6)
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 4, 6}
|
|
||||||
assert 2 in cache
|
|
||||||
assert 4 in cache
|
|
||||||
assert 6 in cache
|
|
||||||
|
|
||||||
cache.remove_oldest()
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 6}
|
|
||||||
assert cache._remove_counter == 4
|
|
||||||
|
|
||||||
cache.clear()
|
|
||||||
assert len(cache) == 0
|
|
||||||
assert cache._remove_counter == 6
|
|
||||||
|
|
||||||
cache._remove_counter = 0
|
|
||||||
|
|
||||||
cache[1] = 1
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache[1] = 1
|
|
||||||
assert len(cache) == 1
|
|
||||||
|
|
||||||
cache[2] = 2
|
|
||||||
assert len(cache) == 2
|
|
||||||
|
|
||||||
cache[3] = 3
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {1, 2, 3}
|
|
||||||
|
|
||||||
cache[4] = 4
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 3, 4}
|
|
||||||
assert cache._remove_counter == 1
|
|
||||||
assert cache[2] == 2
|
|
||||||
|
|
||||||
cache[5] = 5
|
|
||||||
assert set(cache.cache) == {2, 4, 5}
|
|
||||||
assert cache._remove_counter == 2
|
|
||||||
|
|
||||||
del cache[5]
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache.pop(10)
|
|
||||||
assert len(cache) == 2
|
|
||||||
assert set(cache.cache) == {2, 4}
|
|
||||||
assert cache._remove_counter == 3
|
|
||||||
|
|
||||||
cache[6] = 6
|
|
||||||
assert len(cache) == 3
|
|
||||||
assert set(cache.cache) == {2, 4, 6}
|
|
||||||
assert 2 in cache
|
|
||||||
assert 4 in cache
|
|
||||||
assert 6 in cache
|
|
||||||
|
|
||||||
|
|
||||||
# Unit tests for get_adapter_absolute_path
|
# Unit tests for get_adapter_absolute_path
|
||||||
@patch('os.path.isabs')
|
@patch('os.path.isabs')
|
||||||
def test_get_adapter_absolute_path_absolute(mock_isabs):
|
def test_get_adapter_absolute_path_absolute(mock_isabs):
|
||||||
|
@ -13,11 +13,11 @@ import torch
|
|||||||
from vllm_test_utils.monitor import monitor
|
from vllm_test_utils.monitor import monitor
|
||||||
|
|
||||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||||
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
|
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
||||||
PlaceholderModule, StoreBoolean, bind_kv_cache,
|
MemorySnapshot, PlaceholderModule, StoreBoolean,
|
||||||
deprecate_kwargs, get_open_port, memory_profiling,
|
bind_kv_cache, deprecate_kwargs, get_open_port,
|
||||||
merge_async_iterators, sha256, supports_kw,
|
memory_profiling, merge_async_iterators, sha256,
|
||||||
swap_dict_values)
|
supports_kw, swap_dict_values)
|
||||||
|
|
||||||
from .utils import create_new_process_for_each_test, error_on_warning
|
from .utils import create_new_process_for_each_test, error_on_warning
|
||||||
|
|
||||||
@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
|
|||||||
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
|
assert ctx['layers.0.self_attn'].kv_cache[1] is kv_cache[1][0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLRUCache(LRUCache):
|
||||||
|
|
||||||
|
def _on_remove(self, key, value):
|
||||||
|
if not hasattr(self, "_remove_counter"):
|
||||||
|
self._remove_counter = 0
|
||||||
|
self._remove_counter += 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_lru_cache():
|
||||||
|
cache = TestLRUCache(3)
|
||||||
|
assert cache.stat() == CacheInfo(hits=0, total=0)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
|
||||||
|
|
||||||
|
cache.put(1, 1)
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache.put(1, 1)
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache.put(2, 2)
|
||||||
|
assert len(cache) == 2
|
||||||
|
|
||||||
|
cache.put(3, 3)
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {1, 2, 3}
|
||||||
|
|
||||||
|
cache.put(4, 4)
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 3, 4}
|
||||||
|
assert cache._remove_counter == 1
|
||||||
|
|
||||||
|
assert cache.get(2) == 2
|
||||||
|
assert cache.stat() == CacheInfo(hits=1, total=1)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
|
||||||
|
|
||||||
|
assert cache[2] == 2
|
||||||
|
assert cache.stat() == CacheInfo(hits=2, total=2)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=1, total=1)
|
||||||
|
|
||||||
|
cache.put(5, 5)
|
||||||
|
assert set(cache.cache) == {2, 4, 5}
|
||||||
|
assert cache._remove_counter == 2
|
||||||
|
|
||||||
|
assert cache.pop(5) == 5
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
assert cache.get(-1) is None
|
||||||
|
assert cache.stat() == CacheInfo(hits=2, total=3)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=0, total=1)
|
||||||
|
|
||||||
|
cache.pop(10)
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache.get(10)
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache.put(6, 6)
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 4, 6}
|
||||||
|
assert 2 in cache
|
||||||
|
assert 4 in cache
|
||||||
|
assert 6 in cache
|
||||||
|
|
||||||
|
cache.remove_oldest()
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 6}
|
||||||
|
assert cache._remove_counter == 4
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
assert len(cache) == 0
|
||||||
|
assert cache._remove_counter == 6
|
||||||
|
assert cache.stat() == CacheInfo(hits=0, total=0)
|
||||||
|
assert cache.stat(delta=True) == CacheInfo(hits=0, total=0)
|
||||||
|
|
||||||
|
cache._remove_counter = 0
|
||||||
|
|
||||||
|
cache[1] = 1
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache[1] = 1
|
||||||
|
assert len(cache) == 1
|
||||||
|
|
||||||
|
cache[2] = 2
|
||||||
|
assert len(cache) == 2
|
||||||
|
|
||||||
|
cache[3] = 3
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {1, 2, 3}
|
||||||
|
|
||||||
|
cache[4] = 4
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 3, 4}
|
||||||
|
assert cache._remove_counter == 1
|
||||||
|
assert cache[2] == 2
|
||||||
|
|
||||||
|
cache[5] = 5
|
||||||
|
assert set(cache.cache) == {2, 4, 5}
|
||||||
|
assert cache._remove_counter == 2
|
||||||
|
|
||||||
|
del cache[5]
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache.pop(10)
|
||||||
|
assert len(cache) == 2
|
||||||
|
assert set(cache.cache) == {2, 4}
|
||||||
|
assert cache._remove_counter == 3
|
||||||
|
|
||||||
|
cache[6] = 6
|
||||||
|
assert len(cache) == 3
|
||||||
|
assert set(cache.cache) == {2, 4, 6}
|
||||||
|
assert 2 in cache
|
||||||
|
assert 4 in cache
|
||||||
|
assert 6 in cache
|
||||||
|
|
||||||
|
|
||||||
def test_placeholder_module_error_handling():
|
def test_placeholder_module_error_handling():
|
||||||
placeholder = PlaceholderModule("placeholder_1234")
|
placeholder = PlaceholderModule("placeholder_1234")
|
||||||
|
|
||||||
|
@ -236,6 +236,12 @@ class CacheInfo(NamedTuple):
|
|||||||
|
|
||||||
return self.hits / self.total
|
return self.hits / self.total
|
||||||
|
|
||||||
|
def __sub__(self, other: CacheInfo):
|
||||||
|
return CacheInfo(
|
||||||
|
hits=self.hits - other.hits,
|
||||||
|
total=self.total - other.total,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
||||||
|
|
||||||
@ -243,15 +249,26 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
|||||||
capacity: float,
|
capacity: float,
|
||||||
getsizeof: Optional[Callable[[_V], float]] = None):
|
getsizeof: Optional[Callable[[_V], float]] = None):
|
||||||
super().__init__(capacity, getsizeof)
|
super().__init__(capacity, getsizeof)
|
||||||
|
|
||||||
self.pinned_items = set[_K]()
|
self.pinned_items = set[_K]()
|
||||||
self.capacity = capacity
|
|
||||||
|
|
||||||
self._hits = 0
|
self._hits = 0
|
||||||
self._total = 0
|
self._total = 0
|
||||||
|
self._last_info = CacheInfo(hits=0, total=0)
|
||||||
|
|
||||||
|
def __getitem__(self, key: _K, *, update_info: bool = True) -> _V:
|
||||||
|
value = super().__getitem__(key)
|
||||||
|
|
||||||
|
if update_info:
|
||||||
|
self._hits += 1
|
||||||
|
self._total += 1
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
def __delitem__(self, key: _K) -> None:
|
def __delitem__(self, key: _K) -> None:
|
||||||
run_on_remove = key in self
|
run_on_remove = key in self
|
||||||
value = self.__getitem__(key)
|
value = self.__getitem__(key,
|
||||||
|
update_info=False) # type: ignore[call-arg]
|
||||||
super().__delitem__(key)
|
super().__delitem__(key)
|
||||||
if key in self.pinned_items:
|
if key in self.pinned_items:
|
||||||
# Todo: add warning to inform that del pinned item
|
# Todo: add warning to inform that del pinned item
|
||||||
@ -271,8 +288,32 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
|||||||
"""Return the internal order dictionary (read-only)."""
|
"""Return the internal order dictionary (read-only)."""
|
||||||
return MappingProxyType(self._LRUCache__order) # type: ignore
|
return MappingProxyType(self._LRUCache__order) # type: ignore
|
||||||
|
|
||||||
def stat(self) -> CacheInfo:
|
@property
|
||||||
return CacheInfo(hits=self._hits, total=self._total)
|
def capacity(self) -> float:
|
||||||
|
return self.maxsize
|
||||||
|
|
||||||
|
@property
|
||||||
|
def usage(self) -> float:
|
||||||
|
if self.maxsize == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return self.currsize / self.maxsize
|
||||||
|
|
||||||
|
def stat(self, *, delta: bool = False) -> CacheInfo:
|
||||||
|
"""
|
||||||
|
Gets the cumulative number of hits and queries against this cache.
|
||||||
|
|
||||||
|
If :code:`delta=True`, instead gets these statistics
|
||||||
|
since the last call that also passed :code:`delta=True`.
|
||||||
|
"""
|
||||||
|
info = CacheInfo(hits=self._hits, total=self._total)
|
||||||
|
|
||||||
|
if delta:
|
||||||
|
info_delta = info - self._last_info
|
||||||
|
self._last_info = info
|
||||||
|
info = info_delta
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
def touch(self, key: _K) -> None:
|
def touch(self, key: _K) -> None:
|
||||||
self._LRUCache__update(key) # type: ignore
|
self._LRUCache__update(key) # type: ignore
|
||||||
@ -292,7 +333,8 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
|||||||
_T]] = None) -> Optional[Union[_V, _T]]:
|
_T]] = None) -> Optional[Union[_V, _T]]:
|
||||||
value: Optional[Union[_V, _T]]
|
value: Optional[Union[_V, _T]]
|
||||||
if key in self:
|
if key in self:
|
||||||
value = self.__getitem__(key)
|
value = self.__getitem__(
|
||||||
|
key, update_info=False) # type: ignore[call-arg]
|
||||||
|
|
||||||
self._hits += 1
|
self._hits += 1
|
||||||
else:
|
else:
|
||||||
@ -317,8 +359,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
|||||||
if key not in self:
|
if key not in self:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
value = self[key]
|
value = self.__getitem__(key,
|
||||||
del self[key]
|
update_info=False) # type: ignore[call-arg]
|
||||||
|
self.__delitem__(key)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def put(self, key: _K, value: _V) -> None:
|
def put(self, key: _K, value: _V) -> None:
|
||||||
@ -353,10 +396,6 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
|||||||
while self.currsize > self.capacity:
|
while self.currsize > self.capacity:
|
||||||
self.remove_oldest()
|
self.remove_oldest()
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
while len(self) > 0:
|
|
||||||
self.remove_oldest(remove_pinned=True)
|
|
||||||
|
|
||||||
def popitem(self, remove_pinned: bool = False):
|
def popitem(self, remove_pinned: bool = False):
|
||||||
"""Remove and return the `(key, value)` pair least recently used."""
|
"""Remove and return the `(key, value)` pair least recently used."""
|
||||||
if not remove_pinned:
|
if not remove_pinned:
|
||||||
@ -372,6 +411,14 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
|
|||||||
value = self.pop(cast(_K, lru_key))
|
value = self.pop(cast(_K, lru_key))
|
||||||
return (lru_key, value)
|
return (lru_key, value)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
while len(self) > 0:
|
||||||
|
self.remove_oldest(remove_pinned=True)
|
||||||
|
|
||||||
|
self._hits = 0
|
||||||
|
self._total = 0
|
||||||
|
self._last_info = CacheInfo(hits=0, total=0)
|
||||||
|
|
||||||
|
|
||||||
class PyObjectCache:
|
class PyObjectCache:
|
||||||
"""Used to cache python objects to avoid object allocations
|
"""Used to cache python objects to avoid object allocations
|
||||||
|
@ -50,7 +50,7 @@ class MirroredProcessingCache:
|
|||||||
|
|
||||||
full_mm_inputs = list[Optional[MultiModalKwargs]]()
|
full_mm_inputs = list[Optional[MultiModalKwargs]]()
|
||||||
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
||||||
if mm_hash in self.mm_cache:
|
if self.mm_cache.get(mm_hash) is not None:
|
||||||
mm_input = None
|
mm_input = None
|
||||||
else:
|
else:
|
||||||
self.mm_cache[mm_hash] = mm_input
|
self.mm_cache[mm_hash] = mm_input
|
||||||
|
Loading…
x
Reference in New Issue
Block a user