[MISC] Add prefix cache hit rate to metrics (#7606)
This commit is contained in:
parent
df845b2b46
commit
3ac50b47d0
@ -682,6 +682,32 @@ class TestPrefixCachingBlockAllocator:
|
||||
|
||||
assert new_block[0].block_id == last_block_id
|
||||
|
||||
# Test case for cache metrics
|
||||
@staticmethod
def test_metric():
    """Check the prefix cache hit rate reported by the allocator while the
    same block content is allocated over and over."""
    block_size = 16
    allocator = PrefixCachingBlockAllocator(num_blocks=4,
                                            block_size=block_size)
    token_ids = list(range(block_size))

    def allocate_same_content():
        # Every call requests an immutable block with identical tokens.
        allocator.allocate_immutable_block(prev_block=None,
                                           token_ids=token_ids)

    # Nothing has been queried yet (0/0): the rate is reported as 0.0.
    assert allocator.get_prefix_cache_hit_rate() == 0.0

    # First allocation: 0 hits out of 1 query.
    allocate_same_content()
    assert allocator.get_prefix_cache_hit_rate() == 0.0

    # Second allocation of the same content: 1 hit out of 2 queries.
    allocate_same_content()
    assert allocator.get_prefix_cache_hit_rate() == 0.5

    # Keep allocating the same content so that the number of queries
    # spans more than one metric block.
    for _ in range(2, 1005):
        allocate_same_content()
    assert allocator.get_prefix_cache_hit_rate() > 0.99
|
||||
|
||||
@staticmethod
|
||||
def create_immutable_chain(
|
||||
block_size: int,
|
||||
|
@ -34,6 +34,9 @@ def test_block_allocator(
|
||||
assert (first_block == second_block)
|
||||
assert (second_block.ref_count == 2)
|
||||
|
||||
# Check metric: 1 hit of 2 queries
|
||||
assert block_allocator.get_prefix_cache_hit_rate() == 0.5
|
||||
|
||||
# Free the first_block and confirm that the ref_count is correctly
|
||||
# decremented on the second block
|
||||
block_allocator.free(first_block)
|
||||
@ -48,6 +51,10 @@ def test_block_allocator(
|
||||
assert (first_block == second_block)
|
||||
assert (first_block.block_hash == block_hash)
|
||||
|
||||
# Allocate one more time to get 3/4 hit rate for easy checking
|
||||
block_allocator.allocate(block_hash, 0)
|
||||
assert block_allocator.get_prefix_cache_hit_rate() == 0.75
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_blocks", [16])
|
||||
def test_eviction(num_blocks: int, ):
|
||||
|
@ -1,4 +1,5 @@
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
@ -282,6 +283,58 @@ class BlockList:
|
||||
return self._block_ids
|
||||
|
||||
|
||||
@dataclass
class CacheMetricData:
    """A utility dataclass to maintain a running cache hit rate.

    Queries are grouped into fixed-size "blocks" of ``block_size`` queries.
    Instead of keeping ever-growing hit/query totals, only the average hit
    rate of the completed blocks plus the counters of the block currently
    being filled are stored. With

        BS = number of queries per block (``block_size``),
        nB = number of completed blocks,
        HR = average hit rate over the (nB x BS) completed queries,
        Q  = queries in the current, incomplete block (< BS),
        H  = hits in the current, incomplete block (<= Q),

    the hit rate is computed as

        hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)

    Raises:
        ValueError: on construction, if ``block_size`` is not positive.
    """
    # Number of blocks (of ``block_size`` queries each) already folded into
    # ``completed_block_cache_hit_rate``.
    num_completed_blocks: int = 0
    # Average hit rate over all completed blocks.
    completed_block_cache_hit_rate: float = 0.0
    # Queries recorded in the block currently being filled.
    num_incompleted_block_queries: int = 0
    # Hits recorded in the block currently being filled.
    num_incompleted_block_hit: int = 0
    # Number of queries that make up one block.
    block_size: int = 1000

    def __post_init__(self) -> None:
        # A zero block size makes get_hit_rate() divide by zero and a
        # negative one can never complete a block, so reject both up front.
        if self.block_size <= 0:
            raise ValueError(
                f"block_size must be positive, got {self.block_size}")

    def query(self, hit: bool) -> None:
        """Record one cache lookup.

        Args:
            hit: True if the lookup was served from the cache.
        """
        self.num_incompleted_block_queries += 1
        if hit:
            self.num_incompleted_block_hit += 1

        if self.num_incompleted_block_queries != self.block_size:
            return

        # The current block is full: fold its hit rate into the running
        # average of completed blocks and start a fresh block.
        hit_rate = (self.num_incompleted_block_hit /
                    self.num_incompleted_block_queries)
        self.completed_block_cache_hit_rate = (
            self.completed_block_cache_hit_rate * self.num_completed_blocks
            + hit_rate) / (self.num_completed_blocks + 1)
        self.num_incompleted_block_queries = 0
        self.num_incompleted_block_hit = 0
        self.num_completed_blocks += 1

    def get_hit_rate(self) -> float:
        """Return the hit rate over all recorded queries.

        Returns:
            A value in [0, 1]; 0.0 when no query has been recorded yet.
        """
        # Fraction of a block covered by the queries not yet folded in.
        incomplete_ratio = self.num_incompleted_block_queries / self.block_size
        total_blocks = self.num_completed_blocks + incomplete_ratio
        if total_blocks == 0:
            return 0.0

        completed_block_hit, incompleted_block_hit = 0.0, 0.0
        if self.num_completed_blocks > 0:
            completed_block_hit = (self.completed_block_cache_hit_rate *
                                   self.num_completed_blocks)
        if self.num_incompleted_block_queries > 0:
            incompleted_hit_rate = (self.num_incompleted_block_hit /
                                    self.num_incompleted_block_queries)
            # Weight the partial block by the share of a block it fills.
            incompleted_block_hit = (incompleted_hit_rate * incomplete_ratio)
        return (completed_block_hit + incompleted_block_hit) / total_blocks
|
||||
|
||||
|
||||
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
|
||||
"""Retrieves all the blocks in a sequence starting from the last block.
|
||||
|
||||
|
@ -323,6 +323,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return frozenset(self._block_ids_to_allocator.keys())
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator registered for
    ``device``. -1 means not supported or disabled.

    ``device`` is asserted to be a key of ``self._allocators``.
    """
    allocators = self._allocators
    assert device in allocators
    device_allocator = allocators[device]
    return device_allocator.get_prefix_cache_hit_rate()
|
||||
|
||||
def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
|
||||
"""Returns and clears the mapping of source to destination block IDs.
|
||||
Will be called after every swapping operations for now, and after every
|
||||
|
@ -186,6 +186,11 @@ class BlockAllocator(ABC):
|
||||
num_lookahead_slots: int = 0) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
class NoFreeBlocksError(ValueError):
|
||||
pass
|
||||
|
||||
@ -278,3 +283,8 @@ class DeviceAwareBlockAllocator(ABC):
|
||||
There is at most one null block per allocator.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
@ -341,6 +341,9 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
|
||||
block.block_id = block_id # Assign block_id
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
    """Always return -1, the value the allocator interface documents as
    "not supported or disabled"."""
    hit_rate_unavailable = -1
    return hit_rate_unavailable
|
||||
|
||||
|
||||
class NaiveBlock(Block):
|
||||
"""An implementation of the Block class that does not support prefix
|
||||
|
@ -1,9 +1,8 @@
|
||||
"""Token blocks."""
|
||||
|
||||
from os.path import commonprefix
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||
|
||||
from vllm.core.block.common import (CopyOnWriteTracker,
|
||||
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
|
||||
from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
|
||||
@ -107,6 +106,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
self._cow_tracker = CopyOnWriteTracker(
|
||||
refcounter=self._refcounter.as_readonly())
|
||||
|
||||
self.metric_data = CacheMetricData()
|
||||
|
||||
# Implements Block.Factory.
|
||||
def _create_block(
|
||||
self,
|
||||
@ -155,9 +156,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
|
||||
cached_block_id = self._cached_blocks.get(block.content_hash, None)
|
||||
if cached_block_id is not None:
|
||||
self.metric_data.query(hit=True)
|
||||
block.block_id = cached_block_id
|
||||
self._incr_refcount_cached_block(block)
|
||||
return block
|
||||
self.metric_data.query(hit=False)
|
||||
self._block_pool.free_block(block)
|
||||
|
||||
# No cached block => Allocate a new block
|
||||
@ -404,6 +407,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
def all_block_ids(self) -> FrozenSet[int]:
|
||||
return self._hashless_allocator.all_block_ids
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate computed by ``self.metric_data``."""
    hit_rate = self.metric_data.get_hit_rate()
    return hit_rate
|
||||
|
||||
def is_block_cached(self, block: Block) -> bool:
|
||||
assert block.content_hash is not None
|
||||
if block.content_hash in self._cached_blocks:
|
||||
|
@ -8,6 +8,7 @@ from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple
|
||||
|
||||
from vllm.block import BlockTable, PhysicalTokenBlock
|
||||
from vllm.core.block.common import CacheMetricData
|
||||
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
|
||||
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
@ -60,6 +61,11 @@ class BlockAllocatorBase(ABC):
|
||||
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prefix_cache_hit_rate(self) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
||||
|
||||
class CachedBlockAllocator(BlockAllocatorBase):
|
||||
"""Manages free physical token blocks for a device.
|
||||
@ -85,6 +91,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
|
||||
|
||||
self.default_hash_ctr = count()
|
||||
|
||||
self.cache_metric_data = CacheMetricData()
|
||||
|
||||
def allocate_block(self, block_hash: int,
|
||||
num_hashed_tokens: int) -> PhysicalTokenBlock:
|
||||
if self.current_num_blocks == self.num_blocks:
|
||||
@ -105,15 +113,17 @@ class CachedBlockAllocator(BlockAllocatorBase):
|
||||
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
|
||||
if block_hash is None:
|
||||
block_hash = next(self.default_hash_ctr)
|
||||
|
||||
if block_hash in self.evictor:
|
||||
assert block_hash not in self.cached_blocks
|
||||
block = self.evictor.remove(block_hash)
|
||||
assert block.ref_count == 0
|
||||
self.cached_blocks[block_hash] = block
|
||||
block.ref_count += 1
|
||||
assert block.block_hash == block_hash
|
||||
return block
|
||||
if block_hash not in self.cached_blocks:
|
||||
|
||||
if block_hash in self.cached_blocks:
|
||||
self.cache_metric_data.query(hit=True)
|
||||
else:
|
||||
self.cache_metric_data.query(hit=False)
|
||||
self.cached_blocks[block_hash] = self.allocate_block(
|
||||
block_hash, num_hashed_tokens)
|
||||
block = self.cached_blocks[block_hash]
|
||||
@ -150,6 +160,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
|
||||
del self.cached_blocks[old_hash]
|
||||
self.cached_blocks[block_hash] = block
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate computed by ``self.cache_metric_data``."""
    hit_rate = self.cache_metric_data.get_hit_rate()
    return hit_rate
|
||||
|
||||
|
||||
class UncachedBlockAllocator(BlockAllocatorBase):
|
||||
"""Manages free physical token blocks for a device.
|
||||
@ -209,6 +222,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
|
||||
raise NotImplementedError(
|
||||
"Invalid codepath for uncached block allocator.")
|
||||
|
||||
def get_prefix_cache_hit_rate(self) -> float:
    """Always return -1, the value the allocator interface documents as
    "not supported or disabled"."""
    hit_rate_unavailable = -1
    return hit_rate_unavailable
|
||||
|
||||
|
||||
class BlockSpaceManagerV1(BlockSpaceManager):
|
||||
"""Manages the mapping between logical and physical token blocks."""
|
||||
@ -705,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
|
||||
if self.enable_caching:
|
||||
for seq in seq_group.get_seqs():
|
||||
self.compute_full_blocks_in_seq(seq)
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for ``device``.

    ``Device.GPU`` is served by ``self.gpu_allocator`` and ``Device.CPU``
    by ``self.cpu_allocator``.

    Raises:
        ValueError: if ``device`` is neither ``Device.GPU`` nor
            ``Device.CPU``.
    """
    if device == Device.GPU:
        allocator = self.gpu_allocator
    elif device == Device.CPU:
        allocator = self.cpu_allocator
    else:
        raise ValueError(f"Invalid device: {device}")
    return allocator.get_prefix_cache_hit_rate()
|
||||
|
@ -441,6 +441,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
def get_num_free_cpu_blocks(self) -> int:
|
||||
return self.block_allocator.get_num_free_blocks(Device.CPU)
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Forward the query for ``device`` to ``self.block_allocator`` and
    return the hit rate it reports."""
    hit_rate = self.block_allocator.get_prefix_cache_hit_rate(device)
    return hit_rate
|
||||
|
||||
def _can_swap(self,
|
||||
seq_group: SequenceGroup,
|
||||
device: Device,
|
||||
|
@ -2,6 +2,7 @@ from typing import List, Tuple
|
||||
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
|
||||
@ -81,3 +82,6 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
|
||||
pass
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Always return -1 regardless of ``device``; the block space manager
    interface documents -1 as "not supported or disabled"."""
    hit_rate_unavailable = -1
    return hit_rate_unavailable
|
||||
|
@ -85,19 +85,21 @@ class LRUEvictor(Evictor):
|
||||
if len(self.free_table) == 0:
|
||||
raise ValueError("No usable cache memory left")
|
||||
|
||||
evicted_block = next(iter(self.free_table.values()))
|
||||
evicted_block_id = next(iter(self.free_table.keys()))
|
||||
evicted_block, evicted_block_id = None, None
|
||||
# The blocks with the lowest timestamps should be placed consecutively
|
||||
# at the start of OrderedDict. Loop through all these blocks to
|
||||
# find the one with maximum number of hashed tokens.
|
||||
for _id, block in self.free_table.items():
|
||||
if evicted_block is None:
|
||||
evicted_block, evicted_block_id = block, _id
|
||||
continue
|
||||
if evicted_block.last_accessed < block.last_accessed:
|
||||
break
|
||||
if (evicted_block.last_accessed == block.last_accessed and
|
||||
evicted_block.num_hashed_tokens < block.num_hashed_tokens):
|
||||
evicted_block = block
|
||||
evicted_block_id = _id
|
||||
if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
|
||||
evicted_block, evicted_block_id = block, _id
|
||||
|
||||
assert evicted_block is not None
|
||||
assert evicted_block_id is not None
|
||||
self.free_table.pop(evicted_block_id)
|
||||
|
||||
return evicted_block_id, evicted_block.content_hash
|
||||
@ -110,7 +112,6 @@ class LRUEvictor(Evictor):
|
||||
|
||||
def update(self, block_id: int, last_accessed: float):
|
||||
self.free_table[block_id].last_accessed = last_accessed
|
||||
self.free_table.move_to_end(block_id)
|
||||
|
||||
def remove(self, block_id: int):
|
||||
if block_id not in self.free_table:
|
||||
|
@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
|
||||
from typing import Tuple
|
||||
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class AllocStatus(enum.Enum):
|
||||
@ -116,3 +117,8 @@ class BlockSpaceManager(ABC):
|
||||
@abstractmethod
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
"""Prefix cache hit rate. -1 means not supported or disabled."""
|
||||
pass
|
||||
|
@ -14,7 +14,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta,
|
||||
SequenceStatus)
|
||||
from vllm.utils import PyObjectCache
|
||||
from vllm.utils import Device, PyObjectCache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -447,6 +447,9 @@ class Scheduler:
|
||||
return len(self.waiting) != 0 or len(self.running) != 0 or len(
|
||||
self.swapped) != 0
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Forward the query for ``device`` to ``self.block_manager`` and
    return the hit rate it reports."""
    hit_rate = self.block_manager.get_prefix_cache_hit_rate(device)
    return hit_rate
|
||||
|
||||
def get_num_unfinished_seq_groups(self) -> int:
|
||||
return len(self.waiting) + len(self.running) + len(self.swapped)
|
||||
|
||||
|
@ -47,7 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
|
||||
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import Counter
|
||||
from vllm.utils import Counter, Device
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -1390,6 +1390,13 @@ class LLMEngine:
|
||||
for scheduler in self.scheduler)
|
||||
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
|
||||
|
||||
# Prefix Cache Hit Rate. Note that we always use
|
||||
# the cache hit rate of the first virtual engine.
|
||||
cpu_prefix_cache_hit_rate = self.scheduler[
|
||||
0].get_prefix_cache_hit_rate(Device.CPU)
|
||||
gpu_prefix_cache_hit_rate = self.scheduler[
|
||||
0].get_prefix_cache_hit_rate(Device.GPU)
|
||||
|
||||
# Iteration stats
|
||||
num_prompt_tokens_iter = 0
|
||||
num_generation_tokens_iter = 0
|
||||
@ -1498,6 +1505,9 @@ class LLMEngine:
|
||||
# KV Cache Usage in %
|
||||
gpu_cache_usage_sys=gpu_cache_usage_sys,
|
||||
cpu_cache_usage_sys=cpu_cache_usage_sys,
|
||||
# Prefix Cache Hit Rate
|
||||
cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
|
||||
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
|
||||
|
||||
# Iteration stats
|
||||
num_prompt_tokens_iter=num_prompt_tokens_iter,
|
||||
|
@ -71,6 +71,17 @@ class Metrics:
|
||||
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
# Prefix caching block hit rate
|
||||
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
|
||||
name="vllm:cpu_prefix_cache_hit_rate",
|
||||
documentation="CPU prefix cache block hit rate.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
|
||||
name="vllm:gpu_prefix_cache_hit_rate",
|
||||
documentation="GPU prefix cache block hit rate.",
|
||||
labelnames=labelnames,
|
||||
multiprocess_mode="sum")
|
||||
|
||||
# Iteration stats
|
||||
self.counter_num_preemption = self._counter_cls(
|
||||
@ -351,7 +362,13 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
stats.gpu_cache_usage_sys * 100,
|
||||
stats.cpu_cache_usage_sys * 100,
|
||||
)
|
||||
|
||||
if (stats.cpu_prefix_cache_hit_rate >= 0
|
||||
or stats.gpu_prefix_cache_hit_rate >= 0):
|
||||
logger.info(
|
||||
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
|
||||
stats.gpu_prefix_cache_hit_rate * 100,
|
||||
stats.cpu_prefix_cache_hit_rate * 100,
|
||||
)
|
||||
if self.spec_decode_metrics is not None:
|
||||
logger.info(
|
||||
self._format_spec_decode_metrics_str(
|
||||
@ -423,6 +440,10 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
stats.gpu_cache_usage_sys)
|
||||
self._log_gauge(self.metrics.gauge_cpu_cache_usage,
|
||||
stats.cpu_cache_usage_sys)
|
||||
self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
|
||||
stats.cpu_prefix_cache_hit_rate)
|
||||
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
|
||||
stats.gpu_prefix_cache_hit_rate)
|
||||
|
||||
# Iteration level data
|
||||
self._log_counter(self.metrics.counter_num_preemption,
|
||||
|
@ -32,6 +32,9 @@ class Stats:
|
||||
# KV Cache Usage in %
|
||||
gpu_cache_usage_sys: float
|
||||
cpu_cache_usage_sys: float
|
||||
# Prefix caching block hit rate
|
||||
cpu_prefix_cache_hit_rate: float
|
||||
gpu_prefix_cache_hit_rate: float
|
||||
|
||||
# Iteration stats (should have _iter suffix)
|
||||
num_prompt_tokens_iter: int
|
||||
|
Loading…
x
Reference in New Issue
Block a user