[MISC] Add prefix cache hit rate to metrics (#7606)

Cody Yu 2024-08-19 11:52:07 -07:00 committed by GitHub
parent df845b2b46
commit 3ac50b47d0
16 changed files with 200 additions and 16 deletions

View File

@@ -682,6 +682,32 @@ class TestPrefixCachingBlockAllocator:
assert new_block[0].block_id == last_block_id
# Test case for cache metrics
@staticmethod
def test_metric():
block_size = 16
allocator = PrefixCachingBlockAllocator(num_blocks=4,
block_size=block_size)
# Test when no query (0/0)
assert allocator.get_prefix_cache_hit_rate() == 0.0
token_ids = list(range(block_size))
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids)
# Test 0/1 hit rate
assert allocator.get_prefix_cache_hit_rate() == 0.0
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids)
# Test 1/2 hit rate
assert allocator.get_prefix_cache_hit_rate() == 0.5
# Test more than one block
for _ in range(2, 1005):
allocator.allocate_immutable_block(prev_block=None,
token_ids=token_ids)
assert allocator.get_prefix_cache_hit_rate() > 0.99
@staticmethod
def create_immutable_chain(
block_size: int,

View File

@@ -34,6 +34,9 @@ def test_block_allocator(
assert (first_block == second_block)
assert (second_block.ref_count == 2)
# Check metric: 1 hit of 2 queries
assert block_allocator.get_prefix_cache_hit_rate() == 0.5
# Free the first_block and confirm that the ref_count is correctly
# decremented on the second block
block_allocator.free(first_block)
@@ -48,6 +51,10 @@ def test_block_allocator(
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)
# Allocate one more time to get 3/4 hit rate for easy checking
block_allocator.allocate(block_hash, 0)
assert block_allocator.get_prefix_cache_hit_rate() == 0.75
@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):

View File

@@ -1,4 +1,5 @@
from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
from vllm.core.block.interfaces import Block, BlockAllocator
@@ -282,6 +283,58 @@ class BlockList:
return self._block_ids
@dataclass
class CacheMetricData:
"""A utility dataclass to maintain cache metric.
To avoid overflow, we maintain the hit rate in block granularity, so that
we can maintain a single hit rate for n_completed_block x block_size,
and calculate the real time hit rate by the following:
BS = The number of queries per block.
nB = The number of completed blocks.
HR = hit rate of (nB x BS) queries.
Q = current number of queries (< BS).
H = current number of hits (< BS).
hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)
"""
num_completed_blocks: int = 0
completed_block_cache_hit_rate: float = 0.0
num_incompleted_block_queries: int = 0
num_incompleted_block_hit: int = 0
block_size: int = 1000
def query(self, hit: bool):
self.num_incompleted_block_queries += 1
self.num_incompleted_block_hit += 1 if hit else 0
# When a block is completed, update the cache hit rate
# and reset the incomplete numbers.
if self.num_incompleted_block_queries == self.block_size:
hit_rate = (self.num_incompleted_block_hit /
self.num_incompleted_block_queries)
self.completed_block_cache_hit_rate = (
self.completed_block_cache_hit_rate * self.num_completed_blocks
+ hit_rate) / (self.num_completed_blocks + 1)
self.num_incompleted_block_queries = 0
self.num_incompleted_block_hit = 0
self.num_completed_blocks += 1
def get_hit_rate(self):
incomplete_ratio = self.num_incompleted_block_queries / self.block_size
total_blocks = self.num_completed_blocks + incomplete_ratio
if total_blocks == 0:
return 0.0
completed_block_hit, incompleted_block_hit = 0.0, 0.0
if self.num_completed_blocks > 0:
completed_block_hit = (self.completed_block_cache_hit_rate *
self.num_completed_blocks)
if self.num_incompleted_block_queries > 0:
incompleted_hit_rate = (self.num_incompleted_block_hit /
self.num_incompleted_block_queries)
incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
return (completed_block_hit + incompleted_block_hit) / total_blocks
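
As a sanity check on the formula above, here is a minimal sketch of how the running hit rate evolves, assuming the CacheMetricData definition above; block_size=4 is deliberately tiny so the rollover arithmetic is easy to follow:

# A minimal usage sketch, assuming the CacheMetricData class above.
metric = CacheMetricData(block_size=4)

# Three hits and one miss complete one block: HR = 3/4.
for hit in (True, True, True, False):
    metric.query(hit)
assert metric.get_hit_rate() == 0.75

# One more hit opens a partial block with H/Q = 1/1 and weight Q/BS = 1/4:
# hit rate = ((0.75 x 1) + (1/1) x (1/4)) / (1 + 1/4) = 0.8
metric.query(hit=True)
assert abs(metric.get_hit_rate() - 0.8) < 1e-9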
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block.

View File

@@ -323,6 +323,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys())
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
assert device in self._allocators
return self._allocators[device].get_prefix_cache_hit_rate()
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operation for now, and after every

View File

@@ -186,6 +186,11 @@ class BlockAllocator(ABC):
num_lookahead_slots: int = 0) -> int:
pass
@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
class NoFreeBlocksError(ValueError):
pass
@@ -278,3 +283,8 @@ class DeviceAwareBlockAllocator(ABC):
There is at most one null block per allocator.
"""
pass
@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
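
Since every implementation of this interface returns -1 when prefix caching is unsupported or disabled, consumers should guard before treating the value as a ratio. A minimal sketch; the helper name is hypothetical and not part of this change:

def format_hit_rate(rate: float) -> str:
    # Negative values signal "not supported or disabled" per the interface.
    if rate < 0:
        return "N/A"
    return f"{rate * 100:.2f}%"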

View File

@@ -341,6 +341,9 @@ class NaiveBlockAllocator(BlockAllocator):
block.block_id = block_id # Assign block_id
def get_prefix_cache_hit_rate(self) -> float:
return -1
class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix

View File

@@ -1,9 +1,8 @@
"""Token blocks."""
from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
-from vllm.core.block.common import (CopyOnWriteTracker,
+from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
@@ -107,6 +106,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly())
self.metric_data = CacheMetricData()
# Implements Block.Factory.
def _create_block(
self,
@@ -155,9 +156,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None:
self.metric_data.query(hit=True)
block.block_id = cached_block_id
self._incr_refcount_cached_block(block)
return block
self.metric_data.query(hit=False)
self._block_pool.free_block(block)
# No cached block => Allocate a new block
@@ -404,6 +407,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids
def get_prefix_cache_hit_rate(self) -> float:
return self.metric_data.get_hit_rate()
def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None
if block.content_hash in self._cached_blocks:

View File

@@ -8,6 +8,7 @@ from typing import Sequence as GenericSequence
from typing import Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.common import CacheMetricData
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@@ -60,6 +61,11 @@ class BlockAllocatorBase(ABC):
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass
@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
@@ -85,6 +91,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
self.default_hash_ctr = count()
self.cache_metric_data = CacheMetricData()
def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
@@ -105,15 +113,17 @@ class CachedBlockAllocator(BlockAllocatorBase):
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if block_hash is None:
block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
-        if block_hash not in self.cached_blocks:
+        if block_hash in self.cached_blocks:
+            self.cache_metric_data.query(hit=True)
+        else:
+            self.cache_metric_data.query(hit=False)
             self.cached_blocks[block_hash] = self.allocate_block(
                 block_hash, num_hashed_tokens)
         block = self.cached_blocks[block_hash]
@@ -150,6 +160,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block
def get_prefix_cache_hit_rate(self) -> float:
return self.cache_metric_data.get_hit_rate()
class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device.
@@ -209,6 +222,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
raise NotImplementedError(
"Invalid codepath for uncached block allocator.")
def get_prefix_cache_hit_rate(self) -> float:
return -1
class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks."""
@@ -705,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if self.enable_caching:
for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
if device == Device.GPU:
return self.gpu_allocator.get_prefix_cache_hit_rate()
if device == Device.CPU:
return self.cpu_allocator.get_prefix_cache_hit_rate()
raise ValueError(f"Invalid device: {device}")
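
For illustration, a minimal sketch of probing both devices through this dispatch; `manager` stands for an already-constructed BlockSpaceManagerV1 and is assumed, not shown:

from vllm.utils import Device

for device in (Device.GPU, Device.CPU):
    rate = manager.get_prefix_cache_hit_rate(device)
    # UncachedBlockAllocator returns -1 when prefix caching is disabled.
    print(device, "N/A" if rate < 0 else f"{rate:.2%}")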

View File

@@ -441,6 +441,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_allocator.get_prefix_cache_hit_rate(device)
def _can_swap(self,
seq_group: SequenceGroup,
device: Device,

View File

@@ -2,6 +2,7 @@ from typing import List, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
@@ -81,3 +82,6 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return -1

View File

@@ -85,19 +85,21 @@ class LRUEvictor(Evictor):
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
-        evicted_block = next(iter(self.free_table.values()))
-        evicted_block_id = next(iter(self.free_table.keys()))
+        evicted_block, evicted_block_id = None, None
         # The blocks with the lowest timestamps should be placed consecutively
         # at the start of OrderedDict. Loop through all these blocks to
         # find the one with maximum number of hashed tokens.
         for _id, block in self.free_table.items():
+            if evicted_block is None:
+                evicted_block, evicted_block_id = block, _id
+                continue
             if evicted_block.last_accessed < block.last_accessed:
                 break
-            if (evicted_block.last_accessed == block.last_accessed and
-                    evicted_block.num_hashed_tokens < block.num_hashed_tokens):
-                evicted_block = block
-                evicted_block_id = _id
+            if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
+                evicted_block, evicted_block_id = block, _id
+        assert evicted_block is not None
+        assert evicted_block_id is not None
         self.free_table.pop(evicted_block_id)
         return evicted_block_id, evicted_block.content_hash
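
To make the selection rule concrete, here is a standalone sketch of the same loop over a toy free table; the _Block stand-in is hypothetical and exists only for this illustration:

from collections import OrderedDict
from dataclasses import dataclass

@dataclass
class _Block:  # hypothetical stand-in for the evictor's block entries
    last_accessed: float
    num_hashed_tokens: int

# LRU-ordered: lowest timestamps first, as maintained by update() below.
free_table = OrderedDict([
    (1, _Block(last_accessed=1.0, num_hashed_tokens=16)),
    (2, _Block(last_accessed=1.0, num_hashed_tokens=48)),
    (3, _Block(last_accessed=2.0, num_hashed_tokens=64)),
])

evicted_block, evicted_block_id = None, None
for _id, block in free_table.items():
    if evicted_block is None:
        evicted_block, evicted_block_id = block, _id
        continue
    if evicted_block.last_accessed < block.last_accessed:
        break  # later timestamps cannot be candidates
    if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
        evicted_block, evicted_block_id = block, _id

# Among the oldest blocks (ids 1 and 2), the one with more hashed tokens wins.
assert evicted_block_id == 2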
@@ -110,7 +112,6 @@
def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed
self.free_table.move_to_end(block_id)
def remove(self, block_id: int):
if block_id not in self.free_table:

View File

@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple
from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class AllocStatus(enum.Enum):
@@ -116,3 +117,8 @@ class BlockSpaceManager(ABC):
@abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass

View File

@@ -14,7 +14,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
-from vllm.utils import PyObjectCache
+from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__)
@@ -447,6 +447,9 @@ class Scheduler:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)

View File

@@ -47,7 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
-from vllm.utils import Counter
+from vllm.utils import Counter, Device
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -1390,6 +1390,13 @@ class LLMEngine:
for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Prefix Cache Hit Rate. Note that we always use
# the cache hit rate of the first virtual engine.
cpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.CPU)
gpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.GPU)
# Iteration stats
num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0
@@ -1498,6 +1505,9 @@ class LLMEngine:
# KV Cache Usage in %
gpu_cache_usage_sys=gpu_cache_usage_sys,
cpu_cache_usage_sys=cpu_cache_usage_sys,
# Prefix Cache Hit Rate
cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
# Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter,

View File

@@ -71,6 +71,17 @@ class Metrics:
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames,
multiprocess_mode="sum")
# Prefix caching block hit rate
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:cpu_prefix_cache_hit_rate",
documentation="CPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:gpu_prefix_cache_hit_rate",
documentation="GPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
# Iteration stats
self.counter_num_preemption = self._counter_cls(
@@ -351,7 +362,13 @@ class LoggingStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100,
)
if (stats.cpu_prefix_cache_hit_rate >= 0
or stats.gpu_prefix_cache_hit_rate >= 0):
logger.info(
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100,
)
if self.spec_decode_metrics is not None:
logger.info(
self._format_spec_decode_metrics_str(
@@ -423,6 +440,10 @@ class PrometheusStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
stats.cpu_prefix_cache_hit_rate)
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
stats.gpu_prefix_cache_hit_rate)
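
Once a server is running with these stat loggers, the new gauges appear in the standard Prometheus text exposition. A minimal sketch of reading them back, assuming a local vLLM OpenAI-compatible server exposing /metrics on port 8000:

import urllib.request

with urllib.request.urlopen("http://localhost:8000/metrics") as resp:
    for line in resp.read().decode().splitlines():
        # Keep only the sample lines for the two new hit-rate gauges.
        if "prefix_cache_hit_rate" in line and not line.startswith("#"):
            print(line)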
# Iteration level data
self._log_counter(self.metrics.counter_num_preemption,

View File

@@ -32,6 +32,9 @@ class Stats:
# KV Cache Usage in %
gpu_cache_usage_sys: float
cpu_cache_usage_sys: float
# Prefix caching block hit rate
cpu_prefix_cache_hit_rate: float
gpu_prefix_cache_hit_rate: float
# Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int