[Bugfix] Fix size calculation of processing cache (#15114)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
parent
1fe0fd12d3
commit
3d446433ec
@ -7,15 +7,20 @@ from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import ProcessorMixin
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField)
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
|
||||
PromptIndexTargets, PromptInsertion,
|
||||
PromptReplacement, apply_text_matches,
|
||||
ProcessingCache, PromptIndexTargets,
|
||||
PromptInsertion, PromptReplacement,
|
||||
apply_text_matches,
|
||||
apply_token_matches,
|
||||
find_mm_placeholders,
|
||||
find_text_matches, find_token_matches,
|
||||
@ -890,6 +895,45 @@ def test_find_mm_placeholders(
|
||||
assert result == expected
|
||||
|
||||
|
||||
def _dummy_elem(modality: str, key: str, size: int):
|
||||
return MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key=key,
|
||||
data=torch.empty((size, ), dtype=torch.int8),
|
||||
field=MultiModalSharedField(1),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||
return MultiModalKwargsItem.from_elems([
|
||||
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
|
||||
])
|
||||
|
||||
|
||||
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
|
||||
return MultiModalKwargs.from_items([
|
||||
_dummy_item(modality, size_by_key)
|
||||
for modality, size_by_key in size_by_key_modality.items()
|
||||
])
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(
|
||||
("item", "expected_size"),
|
||||
[
|
||||
(_dummy_item("a", {"a1": 100}), 100),
|
||||
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
|
||||
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
|
||||
],
|
||||
)
|
||||
# yapf: enable
|
||||
def test_cache_item_size(item, expected_size):
|
||||
cache = ProcessingCache.get_lru_cache(2048, type(item))
|
||||
cache[""] = item
|
||||
|
||||
assert cache.currsize == expected_size
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
@pytest.mark.parametrize(
|
||||
("limit", "num_supported", "is_valid"),
|
||||
|
@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
|
||||
from .hasher import MultiModalHasher
|
||||
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
MultiModalKwargsItem, NestedTensors, PlaceholderRange)
|
||||
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
|
||||
@ -853,33 +853,62 @@ class ProcessingCache:
|
||||
|
||||
@staticmethod
|
||||
def get_lru_cache(
|
||||
capacity_gb: int,
|
||||
capacity_gb: float,
|
||||
value_type: type[_V],
|
||||
*,
|
||||
debug: bool = False,
|
||||
) -> LRUCache[str, _V]:
|
||||
|
||||
def get_size(leaf: object) -> int:
|
||||
def get_leaf_size(leaf: object) -> int:
|
||||
# MultiModalKwargs is not a subclass of dict
|
||||
if isinstance(leaf, MultiModalKwargs):
|
||||
return get_item_size(leaf.data)
|
||||
|
||||
# MultiModalKwargsItem is not a subclass of dict
|
||||
if isinstance(leaf, MultiModalKwargsItem):
|
||||
leaf_data = {k: v.data for k, v in leaf.items()}
|
||||
return get_item_size(leaf_data)
|
||||
|
||||
# sys.getsizeof doesn't work for tensors
|
||||
if isinstance(leaf, torch.Tensor):
|
||||
return leaf.nbytes # sys.getsizeof doesn't work for tensors
|
||||
return leaf.nbytes
|
||||
|
||||
return sys.getsizeof(leaf)
|
||||
|
||||
return LRUCache[str, _V](
|
||||
GiB_bytes * capacity_gb,
|
||||
getsizeof=lambda x: json_reduce_leaves(
|
||||
def get_item_size(
|
||||
value: Union[MultiModalKwargs, MultiModalKwargsItem,
|
||||
Mapping[str, NestedTensors]]
|
||||
) -> int:
|
||||
size = json_reduce_leaves(
|
||||
lambda a, b: a + b,
|
||||
json_map_leaves(get_size, x),
|
||||
),
|
||||
)
|
||||
json_map_leaves(get_leaf_size, value),
|
||||
)
|
||||
|
||||
def __init__(self, capacity_gb: int) -> None:
|
||||
if debug:
|
||||
logger.debug("Calculated size of %s to be %.2f GiB",
|
||||
type(value), size / GiB_bytes)
|
||||
|
||||
return size
|
||||
|
||||
return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity_gb: float,
|
||||
*,
|
||||
debug_cache_hit_ratio_steps: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# DEBUG: Set to None to disable
|
||||
self.debug_cache_hit_ratio_steps: Optional[int] = None
|
||||
self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
|
||||
self.debug_cache_hits = 0
|
||||
self.debug_cache_total = 0
|
||||
|
||||
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
|
||||
self._cache = self.get_lru_cache(
|
||||
capacity_gb,
|
||||
MultiModalKwargsItem,
|
||||
debug=bool(debug_cache_hit_ratio_steps),
|
||||
)
|
||||
|
||||
def _maybe_log_cache_stats(self) -> None:
|
||||
steps = self.debug_cache_hit_ratio_steps
|
||||
@ -890,6 +919,9 @@ class ProcessingCache:
|
||||
if total > 0 and total % steps == 0:
|
||||
logger.debug("ProcessingCache: hit_ratio = %.2f",
|
||||
self.debug_cache_hits / total)
|
||||
logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
|
||||
self._cache.currsize / GiB_bytes,
|
||||
self._cache.maxsize / GiB_bytes)
|
||||
|
||||
def get(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user