diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py
index 5fb8ec06..c2226870 100644
--- a/tests/core/block/test_prefix_caching_block.py
+++ b/tests/core/block/test_prefix_caching_block.py
@@ -682,6 +682,32 @@ class TestPrefixCachingBlockAllocator:
 
         assert new_block[0].block_id == last_block_id
 
+    # Test case for cache metrics
+    @staticmethod
+    def test_metric():
+        block_size = 16
+        allocator = PrefixCachingBlockAllocator(num_blocks=4,
+                                                block_size=block_size)
+        # Test when no query (0/0)
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+
+        token_ids = list(range(block_size))
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids)
+        # Test 0/1 hit rate
+        assert allocator.get_prefix_cache_hit_rate() == 0.0
+
+        allocator.allocate_immutable_block(prev_block=None,
+                                           token_ids=token_ids)
+        # Test 1/2 hit rate
+        assert allocator.get_prefix_cache_hit_rate() == 0.5
+
+        # Test more than one block
+        for _ in range(2, 1005):
+            allocator.allocate_immutable_block(prev_block=None,
+                                               token_ids=token_ids)
+        assert allocator.get_prefix_cache_hit_rate() > 0.99
+
     @staticmethod
     def create_immutable_chain(
         block_size: int,
diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py
index 9821dbd0..2dff84b8 100644
--- a/tests/prefix_caching/test_prefix_caching.py
+++ b/tests/prefix_caching/test_prefix_caching.py
@@ -34,6 +34,9 @@ def test_block_allocator(
     assert (first_block == second_block)
     assert (second_block.ref_count == 2)
 
+    # Check metric: 1 hit of 2 queries
+    assert block_allocator.get_prefix_cache_hit_rate() == 0.5
+
     # Free the first_block and confirm that the ref_count is correctly
     # decremented on the second block
     block_allocator.free(first_block)
@@ -48,6 +51,10 @@
     assert (first_block == second_block)
     assert (first_block.block_hash == block_hash)
 
+    # Allocate one more time to get 3/4 hit rate for easy checking
+    block_allocator.allocate(block_hash, 0)
+    assert block_allocator.get_prefix_cache_hit_rate() == 0.75
+
 
 @pytest.mark.parametrize("num_blocks", [16])
 def test_eviction(num_blocks: int, ):
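For reference, the values asserted in the tests above follow directly from the hit/miss counts; a quick sanity check of the arithmetic (illustrative only, not part of the patch):

    # test_metric: 1 miss followed by 1004 hits on the same token prefix
    assert 1004 / 1005 > 0.99
    # test_block_allocator: 1 miss + 3 hits across 4 allocate() calls
    assert 3 / 4 == 0.75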
diff --git a/vllm/core/block/common.py b/vllm/core/block/common.py
index 1e808e21..eb190adf 100644
--- a/vllm/core/block/common.py
+++ b/vllm/core/block/common.py
@@ -1,4 +1,5 @@
 from collections import deque
+from dataclasses import dataclass
 from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
 
 from vllm.core.block.interfaces import Block, BlockAllocator
@@ -282,6 +283,58 @@ class BlockList:
         return self._block_ids
 
 
+@dataclass
+class CacheMetricData:
+    """A utility dataclass to maintain cache metrics.
+    To avoid overflow, we maintain the hit rate at block granularity, so
+    that we only keep a single hit rate per n_completed_block x block_size
+    queries, and calculate the real-time hit rate as follows:
+    BS = The number of queries per block.
+    nB = The number of completed blocks.
+    HR = hit rate of (nB x BS) queries.
+    Q = current number of queries (< BS).
+    H = current number of hits (< BS).
+    hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
+    """
+    num_completed_blocks: int = 0
+    completed_block_cache_hit_rate: float = 0.0
+    num_incompleted_block_queries: int = 0
+    num_incompleted_block_hit: int = 0
+    block_size: int = 1000
+
+    def query(self, hit: bool):
+        self.num_incompleted_block_queries += 1
+        self.num_incompleted_block_hit += 1 if hit else 0
+
+        # When a block is completed, update the cache hit rate
+        # and reset the incomplete numbers.
+        if self.num_incompleted_block_queries == self.block_size:
+            hit_rate = (self.num_incompleted_block_hit /
+                        self.num_incompleted_block_queries)
+            self.completed_block_cache_hit_rate = (
+                self.completed_block_cache_hit_rate * self.num_completed_blocks
+                + hit_rate) / (self.num_completed_blocks + 1)
+            self.num_incompleted_block_queries = 0
+            self.num_incompleted_block_hit = 0
+            self.num_completed_blocks += 1
+
+    def get_hit_rate(self):
+        incomplete_ratio = self.num_incompleted_block_queries / self.block_size
+        total_blocks = self.num_completed_blocks + incomplete_ratio
+        if total_blocks == 0:
+            return 0.0
+
+        completed_block_hit, incompleted_block_hit = 0.0, 0.0
+        if self.num_completed_blocks > 0:
+            completed_block_hit = (self.completed_block_cache_hit_rate *
+                                   self.num_completed_blocks)
+        if self.num_incompleted_block_queries > 0:
+            incompleted_hit_rate = (self.num_incompleted_block_hit /
+                                    self.num_incompleted_block_queries)
+            incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
+        return (completed_block_hit + incompleted_block_hit) / total_blocks
+
+
 def get_all_blocks_recursively(last_block: Block) -> List[Block]:
     """Retrieves all the blocks in a sequence starting from the last block.
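To make the docstring formula concrete, here is a minimal sketch of the math under an artificially small block_size (hypothetical usage, not part of the patch):

    from vllm.core.block.common import CacheMetricData

    data = CacheMetricData(block_size=4)   # BS = 4 for illustration
    for hit in (True, False, True, True):  # completes one block: HR = 3/4
        data.query(hit=hit)
    data.query(hit=True)                   # incomplete block: H = 1, Q = 1
    # hit rate = ((0.75 x 1) + (1/1) x (1/4)) / (1 + 1/4) = 0.8
    assert abs(data.get_hit_rate() - 0.8) < 1e-9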
diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py
index 5287cd9c..c6330df2 100644
--- a/vllm/core/block/cpu_gpu_block_allocator.py
+++ b/vllm/core/block/cpu_gpu_block_allocator.py
@@ -323,6 +323,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
     def all_block_ids(self) -> FrozenSet[int]:
         return frozenset(self._block_ids_to_allocator.keys())
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        assert device in self._allocators
+        return self._allocators[device].get_prefix_cache_hit_rate()
+
     def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
         """Returns and clears the mapping of source to destination block IDs.
         Will be called after every swapping operations for now, and after every
diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py
index ab39832b..f26bc761 100644
--- a/vllm/core/block/interfaces.py
+++ b/vllm/core/block/interfaces.py
@@ -186,6 +186,11 @@ class BlockAllocator(ABC):
                                       num_lookahead_slots: int = 0) -> int:
         pass
 
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
+
     class NoFreeBlocksError(ValueError):
         pass
 
@@ -278,3 +283,8 @@ class DeviceAwareBlockAllocator(ABC):
         There is at most one null block per allocator.
         """
         pass
+
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py
index 14a62c2e..1643fd69 100644
--- a/vllm/core/block/naive_block.py
+++ b/vllm/core/block/naive_block.py
@@ -341,6 +341,9 @@ class NaiveBlockAllocator(BlockAllocator):
 
         block.block_id = block_id  # Assign block_id
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return -1
+
 
 class NaiveBlock(Block):
     """An implementation of the Block class that does not support prefix
diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py
index e145eeba..432a6651 100644
--- a/vllm/core/block/prefix_caching_block.py
+++ b/vllm/core/block/prefix_caching_block.py
@@ -1,9 +1,8 @@
 """Token blocks."""
-
 from os.path import commonprefix
 from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
 
-from vllm.core.block.common import (CopyOnWriteTracker,
+from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
                                     get_all_blocks_recursively)
 from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
 from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
@@ -107,6 +106,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
         self._cow_tracker = CopyOnWriteTracker(
             refcounter=self._refcounter.as_readonly())
 
+        self.metric_data = CacheMetricData()
+
     # Implements Block.Factory.
     def _create_block(
         self,
@@ -155,9 +156,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
 
         cached_block_id = self._cached_blocks.get(block.content_hash, None)
         if cached_block_id is not None:
+            self.metric_data.query(hit=True)
             block.block_id = cached_block_id
             self._incr_refcount_cached_block(block)
             return block
+        self.metric_data.query(hit=False)
         self._block_pool.free_block(block)
 
         # No cached block => Allocate a new block
@@ -404,6 +407,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
     def all_block_ids(self) -> FrozenSet[int]:
         return self._hashless_allocator.all_block_ids
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return self.metric_data.get_hit_rate()
+
     def is_block_cached(self, block: Block) -> bool:
         assert block.content_hash is not None
         if block.content_hash in self._cached_blocks:
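Downstream code can rely on the -1 sentinel documented above to distinguish "disabled or unsupported" from a genuine 0% hit rate. A hypothetical caller-side sketch, where allocator stands for any BlockAllocator implementation (not part of the patch):

    def describe_hit_rate(allocator) -> str:
        # -1 signals "not supported or disabled"; anything else is a ratio.
        rate = allocator.get_prefix_cache_hit_rate()
        if rate < 0:
            return "prefix caching disabled or unsupported"
        return f"prefix cache hit rate: {rate:.2%}"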
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index ad26d3c5..0af04399 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -8,6 +8,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Tuple
 
 from vllm.block import BlockTable, PhysicalTokenBlock
+from vllm.core.block.common import CacheMetricData
 from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
 from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -60,6 +61,11 @@
     def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
         pass
 
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
+
 
 class CachedBlockAllocator(BlockAllocatorBase):
     """Manages free physical token blocks for a device.
@@ -85,6 +91,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
 
         self.default_hash_ctr = count()
 
+        self.cache_metric_data = CacheMetricData()
+
     def allocate_block(self, block_hash: int,
                        num_hashed_tokens: int) -> PhysicalTokenBlock:
         if self.current_num_blocks == self.num_blocks:
@@ -105,15 +113,17 @@
                  num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
         if block_hash is None:
             block_hash = next(self.default_hash_ctr)
+
         if block_hash in self.evictor:
             assert block_hash not in self.cached_blocks
             block = self.evictor.remove(block_hash)
             assert block.ref_count == 0
             self.cached_blocks[block_hash] = block
-            block.ref_count += 1
-            assert block.block_hash == block_hash
-            return block
-        if block_hash not in self.cached_blocks:
+
+        if block_hash in self.cached_blocks:
+            self.cache_metric_data.query(hit=True)
+        else:
+            self.cache_metric_data.query(hit=False)
             self.cached_blocks[block_hash] = self.allocate_block(
                 block_hash, num_hashed_tokens)
         block = self.cached_blocks[block_hash]
@@ -150,6 +160,9 @@
             del self.cached_blocks[old_hash]
             self.cached_blocks[block_hash] = block
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return self.cache_metric_data.get_hit_rate()
+
 
 class UncachedBlockAllocator(BlockAllocatorBase):
     """Manages free physical token blocks for a device.
@@ -209,6 +222,9 @@
         raise NotImplementedError(
             "Invalid codepath for uncached block allocator.")
 
+    def get_prefix_cache_hit_rate(self) -> float:
+        return -1
+
 
 class BlockSpaceManagerV1(BlockSpaceManager):
     """Manages the mapping between logical and physical token blocks."""
@@ -705,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
         if self.enable_caching:
             for seq in seq_group.get_seqs():
                 self.compute_full_blocks_in_seq(seq)
+
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        if device == Device.GPU:
+            return self.gpu_allocator.get_prefix_cache_hit_rate()
+        if device == Device.CPU:
+            return self.cpu_allocator.get_prefix_cache_hit_rate()
+        raise ValueError(f"Invalid device: {device}")
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index b48ea1b1..b7d9451f 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -441,6 +441,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
     def get_num_free_cpu_blocks(self) -> int:
         return self.block_allocator.get_num_free_blocks(Device.CPU)
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return self.block_allocator.get_prefix_cache_hit_rate(device)
+
     def _can_swap(self,
                   seq_group: SequenceGroup,
                   device: Device,
diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py
index f2d67306..3d864a73 100644
--- a/vllm/core/embedding_model_block_manager.py
+++ b/vllm/core/embedding_model_block_manager.py
@@ -2,6 +2,7 @@ from typing import List, Tuple
 
 from vllm.core.interfaces import AllocStatus, BlockSpaceManager
 from vllm.sequence import Sequence, SequenceGroup
+from vllm.utils import Device
 
 
 class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
@@ -81,3 +82,6 @@
 
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
+
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return -1
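Both block-manager versions now answer the same question per device. A hypothetical sketch of how a caller might use this, where block_manager stands for any BlockSpaceManager (not part of the patch):

    from vllm.utils import Device

    def report_hit_rates(block_manager) -> dict:
        # V1 dispatches to its per-device allocators, V2 forwards to the
        # CpuGpuBlockAllocator, and the embedding manager always returns -1.
        return {
            device: block_manager.get_prefix_cache_hit_rate(device)
            for device in (Device.GPU, Device.CPU)
        }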
diff --git a/vllm/core/evictor_v2.py b/vllm/core/evictor_v2.py
index 5b1a208b..0b943e6e 100644
--- a/vllm/core/evictor_v2.py
+++ b/vllm/core/evictor_v2.py
@@ -85,19 +85,21 @@ class LRUEvictor(Evictor):
         if len(self.free_table) == 0:
             raise ValueError("No usable cache memory left")
 
-        evicted_block = next(iter(self.free_table.values()))
-        evicted_block_id = next(iter(self.free_table.keys()))
+        evicted_block, evicted_block_id = None, None
         # The blocks with the lowest timestamps should be placed consecutively
         # at the start of OrderedDict. Loop through all these blocks to
         # find the one with maximum number of hashed tokens.
         for _id, block in self.free_table.items():
+            if evicted_block is None:
+                evicted_block, evicted_block_id = block, _id
+                continue
             if evicted_block.last_accessed < block.last_accessed:
                 break
-            if (evicted_block.last_accessed == block.last_accessed and
-                    evicted_block.num_hashed_tokens < block.num_hashed_tokens):
-                evicted_block = block
-                evicted_block_id = _id
+            if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
+                evicted_block, evicted_block_id = block, _id
 
+        assert evicted_block is not None
+        assert evicted_block_id is not None
         self.free_table.pop(evicted_block_id)
 
         return evicted_block_id, evicted_block.content_hash
@@ -110,7 +112,6 @@ class LRUEvictor(Evictor):
 
     def update(self, block_id: int, last_accessed: float):
         self.free_table[block_id].last_accessed = last_accessed
-        self.free_table.move_to_end(block_id)
 
     def remove(self, block_id: int):
         if block_id not in self.free_table:
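To illustrate the tie-break in the rewritten eviction loop above, here is a self-contained re-creation of the selection logic on plain (last_accessed, num_hashed_tokens) tuples, with made-up data (not part of the patch):

    free_table = {  # ordered oldest-first, as in the OrderedDict
        1: (10.0, 16),
        2: (10.0, 48),
        3: (12.0, 64),
    }
    evicted_id, evicted = None, None
    for _id, (ts, n_tokens) in free_table.items():
        if evicted is None:
            evicted_id, evicted = _id, (ts, n_tokens)
            continue
        if evicted[0] < ts:
            break  # newer timestamp: the oldest run has ended
        if evicted[1] < n_tokens:
            evicted_id, evicted = _id, (ts, n_tokens)
    # Blocks 1 and 2 share the oldest timestamp, so the one covering
    # more hashed tokens (id=2) is evicted first.
    assert evicted_id == 2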
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index 8759ee06..becd0d2e 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
 from typing import Tuple
 
 from vllm.sequence import Sequence, SequenceGroup
+from vllm.utils import Device
 
 
 class AllocStatus(enum.Enum):
@@ -116,3 +117,8 @@
     @abstractmethod
     def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         pass
+
+    @abstractmethod
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        """Prefix cache hit rate. -1 means not supported or disabled."""
+        pass
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 802359d2..3b716e32 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -14,7 +14,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                            SequenceGroupMetadata, SequenceGroupMetadataDelta,
                            SequenceStatus)
-from vllm.utils import PyObjectCache
+from vllm.utils import Device, PyObjectCache
 
 logger = init_logger(__name__)
 
@@ -447,6 +447,9 @@ class Scheduler:
         return len(self.waiting) != 0 or len(self.running) != 0 or len(
             self.swapped) != 0
 
+    def get_prefix_cache_hit_rate(self, device: Device) -> float:
+        return self.block_manager.get_prefix_cache_hit_rate(device)
+
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
 
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index fcf45a38..36cb6ce7 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -47,7 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
     AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
-from vllm.utils import Counter
+from vllm.utils import Counter, Device
 from vllm.version import __version__ as VLLM_VERSION
 
 logger = init_logger(__name__)
@@ -1390,6 +1390,13 @@
                 for scheduler in self.scheduler)
             cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
 
+        # Prefix Cache Hit Rate. Note that we always use
+        # the cache hit rate of the first virtual engine.
+        cpu_prefix_cache_hit_rate = self.scheduler[
+            0].get_prefix_cache_hit_rate(Device.CPU)
+        gpu_prefix_cache_hit_rate = self.scheduler[
+            0].get_prefix_cache_hit_rate(Device.GPU)
+
         # Iteration stats
         num_prompt_tokens_iter = 0
         num_generation_tokens_iter = 0
@@ -1498,6 +1505,9 @@
             # KV Cache Usage in %
             gpu_cache_usage_sys=gpu_cache_usage_sys,
             cpu_cache_usage_sys=cpu_cache_usage_sys,
+            # Prefix Cache Hit Rate
+            cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
+            gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
 
             # Iteration stats
             num_prompt_tokens_iter=num_prompt_tokens_iter,
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 1071786c..74277cae 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -71,6 +71,17 @@
             documentation="CPU KV-cache usage. 1 means 100 percent usage.",
             labelnames=labelnames,
             multiprocess_mode="sum")
+        # Prefix caching block hit rate
+        self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="vllm:cpu_prefix_cache_hit_rate",
+            documentation="CPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
+        self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
+            name="vllm:gpu_prefix_cache_hit_rate",
+            documentation="GPU prefix cache block hit rate.",
+            labelnames=labelnames,
+            multiprocess_mode="sum")
 
         # Iteration stats
         self.counter_num_preemption = self._counter_cls(
@@ -351,7 +362,13 @@
                 stats.gpu_cache_usage_sys * 100,
                 stats.cpu_cache_usage_sys * 100,
             )
-
+            if (stats.cpu_prefix_cache_hit_rate >= 0
+                    or stats.gpu_prefix_cache_hit_rate >= 0):
+                logger.info(
+                    "Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
+                    stats.gpu_prefix_cache_hit_rate * 100,
+                    stats.cpu_prefix_cache_hit_rate * 100,
+                )
             if self.spec_decode_metrics is not None:
                 logger.info(
                     self._format_spec_decode_metrics_str(
@@ -423,6 +440,10 @@
                         stats.gpu_cache_usage_sys)
         self._log_gauge(self.metrics.gauge_cpu_cache_usage,
                         stats.cpu_cache_usage_sys)
+        self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
+                        stats.cpu_prefix_cache_hit_rate)
+        self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
+                        stats.gpu_prefix_cache_hit_rate)
 
         # Iteration level data
         self._log_counter(self.metrics.counter_num_preemption,
diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py
index 7449aafc..1eccb235 100644
--- a/vllm/engine/metrics_types.py
+++ b/vllm/engine/metrics_types.py
@@ -32,6 +32,9 @@ class Stats:
     # KV Cache Usage in %
     gpu_cache_usage_sys: float
     cpu_cache_usage_sys: float
+    # Prefix caching block hit rate
+    cpu_prefix_cache_hit_rate: float
+    gpu_prefix_cache_hit_rate: float
 
     # Iteration stats (should have _iter suffix)
     num_prompt_tokens_iter: int
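For reference, a sketch of what the resulting telemetry might look like. The log format comes from the LoggingStatLogger change above, while the gauge label set is deployment-specific; the values here are made up:

    # Periodic stats log (emitted only when either rate is >= 0); note that
    # a disabled device is still formatted and shows up as -100.00%:
    #   Prefix cache hit rate: GPU: 75.00%, CPU: -100.00%
    # Prometheus scrape (labels depend on deployment):
    #   vllm:gpu_prefix_cache_hit_rate{...} 0.75
    #   vllm:cpu_prefix_cache_hit_rate{...} -1.0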