[MISC] Add prefix cache hit rate to metrics (#7606)

This commit is contained in:
Cody Yu 2024-08-19 11:52:07 -07:00 committed by GitHub
parent df845b2b46
commit 3ac50b47d0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 200 additions and 16 deletions

View File

@ -682,6 +682,32 @@ class TestPrefixCachingBlockAllocator:
assert new_block[0].block_id == last_block_id assert new_block[0].block_id == last_block_id
# Test case for cache metrics
@staticmethod
def test_metric():
    """Check the prefix cache hit rate reported by the allocator.

    The same full block of tokens is allocated repeatedly: the first
    allocation is a miss and every later one is a hit.
    """
    tokens_per_block = 16
    allocator = PrefixCachingBlockAllocator(num_blocks=4,
                                            block_size=tokens_per_block)
    token_ids = list(range(tokens_per_block))

    def allocate_same_block():
        # Every call asks for an immutable block with identical content.
        allocator.allocate_immutable_block(prev_block=None,
                                           token_ids=token_ids)

    # No query has been recorded yet (0/0).
    assert allocator.get_prefix_cache_hit_rate() == 0.0

    # First allocation misses: 0 hits out of 1 query.
    allocate_same_block()
    assert allocator.get_prefix_cache_hit_rate() == 0.0

    # Second allocation hits: 1 hit out of 2 queries.
    allocate_same_block()
    assert allocator.get_prefix_cache_hit_rate() == 0.5

    # Issue 1003 more hits (1005 queries in total) so that the metric spans
    # more than one block of queries; the rate is checked once afterwards.
    remaining_queries = 1005 - 2
    for _ in range(remaining_queries):
        allocate_same_block()
    assert allocator.get_prefix_cache_hit_rate() > 0.99
@staticmethod @staticmethod
def create_immutable_chain( def create_immutable_chain(
block_size: int, block_size: int,

View File

@ -34,6 +34,9 @@ def test_block_allocator(
assert (first_block == second_block) assert (first_block == second_block)
assert (second_block.ref_count == 2) assert (second_block.ref_count == 2)
# Check metric: 1 hit of 2 queries
assert block_allocator.get_prefix_cache_hit_rate() == 0.5
# Free the first_block and confirm that the ref_count is correctly # Free the first_block and confirm that the ref_count is correctly
# decremented on the second block # decremented on the second block
block_allocator.free(first_block) block_allocator.free(first_block)
@ -48,6 +51,10 @@ def test_block_allocator(
assert (first_block == second_block) assert (first_block == second_block)
assert (first_block.block_hash == block_hash) assert (first_block.block_hash == block_hash)
# Allocate one more time to get 3/4 hit rate for easy checking
block_allocator.allocate(block_hash, 0)
assert block_allocator.get_prefix_cache_hit_rate() == 0.75
@pytest.mark.parametrize("num_blocks", [16]) @pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ): def test_eviction(num_blocks: int, ):

View File

@ -1,4 +1,5 @@
from collections import deque from collections import deque
from dataclasses import dataclass
from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple from typing import Deque, Dict, Iterable, List, Optional, Protocol, Tuple
from vllm.core.block.interfaces import Block, BlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator
@ -282,6 +283,58 @@ class BlockList:
return self._block_ids return self._block_ids
@dataclass
class CacheMetricData:
    """A utility dataclass that maintains a running cache hit rate.

    Queries are grouped into fixed-size blocks of ``block_size`` queries.
    Rather than keeping ever-growing hit/query counters, one averaged hit
    rate is kept for all completed blocks, plus exact counters for the block
    currently being filled. The real-time hit rate is derived as follows:

        BS = The number of queries per block (``block_size``).
        nB = The number of completed blocks.
        HR = Hit rate of the (nB x BS) completed queries.
        Q = Number of queries in the current, incomplete block (< BS).
        H = Number of hits in the current, incomplete block (<= Q).

        hit rate = ((HR x nB) + (H / Q) x (Q / BS)) / (nB + Q / BS)

    Raises:
        ValueError: If ``block_size`` is not a positive number at
            construction time.
    """
    # nB: how many blocks of `block_size` queries have been completed.
    num_completed_blocks: int = 0
    # HR: averaged hit rate over all completed blocks.
    completed_block_cache_hit_rate: float = 0.0
    # Q: queries recorded in the block currently being filled.
    num_incompleted_block_queries: int = 0
    # H: hits recorded in the block currently being filled.
    num_incompleted_block_hit: int = 0
    # BS: number of queries that make up one block.
    block_size: int = 1000

    def __post_init__(self) -> None:
        # A zero block size makes get_hit_rate() divide by zero and a
        # negative one yields meaningless rates, so reject both up front.
        if self.block_size <= 0:
            raise ValueError(
                f"block_size must be positive, got {self.block_size}")

    def query(self, hit: bool) -> None:
        """Record one cache lookup.

        Args:
            hit: Whether the lookup was served from the cache.
        """
        self.num_incompleted_block_queries += 1
        if hit:
            self.num_incompleted_block_hit += 1

        if self.num_incompleted_block_queries != self.block_size:
            return

        # The block is complete: fold its hit rate into the running average
        # of completed blocks and reset the in-progress counters.
        block_hit_rate = (self.num_incompleted_block_hit /
                          self.num_incompleted_block_queries)
        self.completed_block_cache_hit_rate = (
            self.completed_block_cache_hit_rate * self.num_completed_blocks +
            block_hit_rate) / (self.num_completed_blocks + 1)
        self.num_incompleted_block_queries = 0
        self.num_incompleted_block_hit = 0
        self.num_completed_blocks += 1

    def get_hit_rate(self) -> float:
        """Return the hit rate over all recorded queries.

        Returns:
            A value in [0, 1]; 0.0 when no query has been recorded yet.
        """
        # Fraction of a block covered by the in-progress queries (Q / BS).
        incomplete_ratio = self.num_incompleted_block_queries / self.block_size
        total_blocks = self.num_completed_blocks + incomplete_ratio
        if total_blocks == 0:
            return 0.0

        completed_block_hit = 0.0
        if self.num_completed_blocks > 0:
            completed_block_hit = (self.completed_block_cache_hit_rate *
                                   self.num_completed_blocks)

        incompleted_block_hit = 0.0
        if self.num_incompleted_block_queries > 0:
            incompleted_hit_rate = (self.num_incompleted_block_hit /
                                    self.num_incompleted_block_queries)
            # Weight the partial block by the share of a block it covers.
            incompleted_block_hit = incompleted_hit_rate * incomplete_ratio

        return (completed_block_hit + incompleted_block_hit) / total_blocks
def get_all_blocks_recursively(last_block: Block) -> List[Block]: def get_all_blocks_recursively(last_block: Block) -> List[Block]:
"""Retrieves all the blocks in a sequence starting from the last block. """Retrieves all the blocks in a sequence starting from the last block.

View File

@ -323,6 +323,11 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
def all_block_ids(self) -> FrozenSet[int]: def all_block_ids(self) -> FrozenSet[int]:
return frozenset(self._block_ids_to_allocator.keys()) return frozenset(self._block_ids_to_allocator.keys())
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for ``device``.

    The value is whatever the per-device allocator reports; -1 means prefix
    caching is not supported or is disabled.
    """
    assert device in self._allocators
    device_allocator = self._allocators[device]
    return device_allocator.get_prefix_cache_hit_rate()
def get_and_reset_swaps(self) -> List[Tuple[int, int]]: def get_and_reset_swaps(self) -> List[Tuple[int, int]]:
"""Returns and clears the mapping of source to destination block IDs. """Returns and clears the mapping of source to destination block IDs.
Will be called after every swapping operations for now, and after every Will be called after every swapping operations for now, and after every

View File

@ -186,6 +186,11 @@ class BlockAllocator(ABC):
num_lookahead_slots: int = 0) -> int: num_lookahead_slots: int = 0) -> int:
pass pass
@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Report the prefix cache hit rate of this allocator.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """
class NoFreeBlocksError(ValueError): class NoFreeBlocksError(ValueError):
pass pass
@ -278,3 +283,8 @@ class DeviceAwareBlockAllocator(ABC):
There is at most one null block per allocator. There is at most one null block per allocator.
""" """
pass pass
@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Report the prefix cache hit rate for the given ``device``.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """

View File

@ -341,6 +341,9 @@ class NaiveBlockAllocator(BlockAllocator):
block.block_id = block_id # Assign block_id block.block_id = block_id # Assign block_id
def get_prefix_cache_hit_rate(self) -> float:
    """Return -1, the value the allocator interface documents as
    "not supported or disabled"."""
    hit_rate_unavailable = -1
    return hit_rate_unavailable
class NaiveBlock(Block): class NaiveBlock(Block):
"""An implementation of the Block class that does not support prefix """An implementation of the Block class that does not support prefix

View File

@ -1,9 +1,8 @@
"""Token blocks.""" """Token blocks."""
from os.path import commonprefix from os.path import commonprefix
from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple
from vllm.core.block.common import (CopyOnWriteTracker, from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively) get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
from vllm.core.block.naive_block import (BlockPool, NaiveBlock, from vllm.core.block.naive_block import (BlockPool, NaiveBlock,
@ -107,6 +106,8 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self._cow_tracker = CopyOnWriteTracker( self._cow_tracker = CopyOnWriteTracker(
refcounter=self._refcounter.as_readonly()) refcounter=self._refcounter.as_readonly())
self.metric_data = CacheMetricData()
# Implements Block.Factory. # Implements Block.Factory.
def _create_block( def _create_block(
self, self,
@ -155,9 +156,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
cached_block_id = self._cached_blocks.get(block.content_hash, None) cached_block_id = self._cached_blocks.get(block.content_hash, None)
if cached_block_id is not None: if cached_block_id is not None:
self.metric_data.query(hit=True)
block.block_id = cached_block_id block.block_id = cached_block_id
self._incr_refcount_cached_block(block) self._incr_refcount_cached_block(block)
return block return block
self.metric_data.query(hit=False)
self._block_pool.free_block(block) self._block_pool.free_block(block)
# No cached block => Allocate a new block # No cached block => Allocate a new block
@ -404,6 +407,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def all_block_ids(self) -> FrozenSet[int]: def all_block_ids(self) -> FrozenSet[int]:
return self._hashless_allocator.all_block_ids return self._hashless_allocator.all_block_ids
def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate accumulated in this allocator's ``metric_data``."""
    metrics = self.metric_data
    return metrics.get_hit_rate()
def is_block_cached(self, block: Block) -> bool: def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None assert block.content_hash is not None
if block.content_hash in self._cached_blocks: if block.content_hash in self._cached_blocks:

View File

@ -8,6 +8,7 @@ from typing import Sequence as GenericSequence
from typing import Set, Tuple from typing import Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.core.block.common import CacheMetricData
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
@ -60,6 +61,11 @@ class BlockAllocatorBase(ABC):
def update_hash(self, block_hash: int, block: PhysicalTokenBlock): def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
pass pass
@abstractmethod
def get_prefix_cache_hit_rate(self) -> float:
    """Report the prefix cache hit rate of this allocator.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """
class CachedBlockAllocator(BlockAllocatorBase): class CachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device. """Manages free physical token blocks for a device.
@ -85,6 +91,8 @@ class CachedBlockAllocator(BlockAllocatorBase):
self.default_hash_ctr = count() self.default_hash_ctr = count()
self.cache_metric_data = CacheMetricData()
def allocate_block(self, block_hash: int, def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock: num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks: if self.current_num_blocks == self.num_blocks:
@ -105,15 +113,17 @@ class CachedBlockAllocator(BlockAllocatorBase):
num_hashed_tokens: int = 0) -> PhysicalTokenBlock: num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
if block_hash is None: if block_hash is None:
block_hash = next(self.default_hash_ctr) block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor: if block_hash in self.evictor:
assert block_hash not in self.cached_blocks assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash) block = self.evictor.remove(block_hash)
assert block.ref_count == 0 assert block.ref_count == 0
self.cached_blocks[block_hash] = block self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash if block_hash in self.cached_blocks:
return block self.cache_metric_data.query(hit=True)
if block_hash not in self.cached_blocks: else:
self.cache_metric_data.query(hit=False)
self.cached_blocks[block_hash] = self.allocate_block( self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens) block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash] block = self.cached_blocks[block_hash]
@ -150,6 +160,9 @@ class CachedBlockAllocator(BlockAllocatorBase):
del self.cached_blocks[old_hash] del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block self.cached_blocks[block_hash] = block
def get_prefix_cache_hit_rate(self) -> float:
    """Return the hit rate accumulated in ``cache_metric_data``."""
    metrics = self.cache_metric_data
    return metrics.get_hit_rate()
class UncachedBlockAllocator(BlockAllocatorBase): class UncachedBlockAllocator(BlockAllocatorBase):
"""Manages free physical token blocks for a device. """Manages free physical token blocks for a device.
@ -209,6 +222,9 @@ class UncachedBlockAllocator(BlockAllocatorBase):
raise NotImplementedError( raise NotImplementedError(
"Invalid codepath for uncached block allocator.") "Invalid codepath for uncached block allocator.")
def get_prefix_cache_hit_rate(self) -> float:
    """Return -1, the value the allocator base class documents as
    "not supported or disabled"."""
    hit_rate_unavailable = -1
    return hit_rate_unavailable
class BlockSpaceManagerV1(BlockSpaceManager): class BlockSpaceManagerV1(BlockSpaceManager):
"""Manages the mapping between logical and physical token blocks.""" """Manages the mapping between logical and physical token blocks."""
@ -705,3 +721,10 @@ class BlockSpaceManagerV1(BlockSpaceManager):
if self.enable_caching: if self.enable_caching:
for seq in seq_group.get_seqs(): for seq in seq_group.get_seqs():
self.compute_full_blocks_in_seq(seq) self.compute_full_blocks_in_seq(seq)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the prefix cache hit rate of the allocator for ``device``.

    Raises:
        ValueError: If ``device`` is neither ``Device.GPU`` nor
            ``Device.CPU``.
    """
    if device == Device.GPU:
        allocator = self.gpu_allocator
    elif device == Device.CPU:
        allocator = self.cpu_allocator
    else:
        raise ValueError(f"Invalid device: {device}")
    return allocator.get_prefix_cache_hit_rate()

View File

@ -441,6 +441,9 @@ class BlockSpaceManagerV2(BlockSpaceManager):
def get_num_free_cpu_blocks(self) -> int: def get_num_free_cpu_blocks(self) -> int:
return self.block_allocator.get_num_free_blocks(Device.CPU) return self.block_allocator.get_num_free_blocks(Device.CPU)
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the hit rate reported by ``block_allocator`` for ``device``."""
    allocator = self.block_allocator
    return allocator.get_prefix_cache_hit_rate(device)
def _can_swap(self, def _can_swap(self,
seq_group: SequenceGroup, seq_group: SequenceGroup,
device: Device, device: Device,

View File

@ -2,6 +2,7 @@ from typing import List, Tuple
from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.sequence import Sequence, SequenceGroup from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class EmbeddingModelBlockSpaceManager(BlockSpaceManager): class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
@ -81,3 +82,6 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass pass
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return -1 for any ``device``; the block space manager interface
    documents -1 as "not supported or disabled"."""
    hit_rate_unavailable = -1
    return hit_rate_unavailable

View File

@ -85,19 +85,21 @@ class LRUEvictor(Evictor):
if len(self.free_table) == 0: if len(self.free_table) == 0:
raise ValueError("No usable cache memory left") raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values())) evicted_block, evicted_block_id = None, None
evicted_block_id = next(iter(self.free_table.keys()))
# The blocks with the lowest timestamps should be placed consecutively # The blocks with the lowest timestamps should be placed consecutively
# at the start of OrderedDict. Loop through all these blocks to # at the start of OrderedDict. Loop through all these blocks to
# find the one with maximum number of hashed tokens. # find the one with maximum number of hashed tokens.
for _id, block in self.free_table.items(): for _id, block in self.free_table.items():
if evicted_block is None:
evicted_block, evicted_block_id = block, _id
continue
if evicted_block.last_accessed < block.last_accessed: if evicted_block.last_accessed < block.last_accessed:
break break
if (evicted_block.last_accessed == block.last_accessed and if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
evicted_block.num_hashed_tokens < block.num_hashed_tokens): evicted_block, evicted_block_id = block, _id
evicted_block = block
evicted_block_id = _id
assert evicted_block is not None
assert evicted_block_id is not None
self.free_table.pop(evicted_block_id) self.free_table.pop(evicted_block_id)
return evicted_block_id, evicted_block.content_hash return evicted_block_id, evicted_block.content_hash
@ -110,7 +112,6 @@ class LRUEvictor(Evictor):
def update(self, block_id: int, last_accessed: float): def update(self, block_id: int, last_accessed: float):
self.free_table[block_id].last_accessed = last_accessed self.free_table[block_id].last_accessed = last_accessed
self.free_table.move_to_end(block_id)
def remove(self, block_id: int): def remove(self, block_id: int):
if block_id not in self.free_table: if block_id not in self.free_table:

View File

@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from typing import Tuple from typing import Tuple
from vllm.sequence import Sequence, SequenceGroup from vllm.sequence import Sequence, SequenceGroup
from vllm.utils import Device
class AllocStatus(enum.Enum): class AllocStatus(enum.Enum):
@ -116,3 +117,8 @@ class BlockSpaceManager(ABC):
@abstractmethod @abstractmethod
def mark_blocks_as_computed(self, seq_group: SequenceGroup): def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass pass
@abstractmethod
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Report the prefix cache hit rate for the given ``device``.

    Implementations return -1 when prefix caching is not supported or is
    disabled.
    """

View File

@ -14,7 +14,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup, from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta, SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus) SequenceStatus)
from vllm.utils import PyObjectCache from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__) logger = init_logger(__name__)
@ -447,6 +447,9 @@ class Scheduler:
return len(self.waiting) != 0 or len(self.running) != 0 or len( return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0 self.swapped) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
    """Return the hit rate reported by ``block_manager`` for ``device``."""
    manager = self.block_manager
    return manager.get_prefix_cache_hit_rate(device)
def get_num_unfinished_seq_groups(self) -> int: def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped) return len(self.waiting) + len(self.running) + len(self.swapped)

View File

@ -47,7 +47,7 @@ from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter from vllm.utils import Counter, Device
from vllm.version import __version__ as VLLM_VERSION from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
@ -1390,6 +1390,13 @@ class LLMEngine:
for scheduler in self.scheduler) for scheduler in self.scheduler)
cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu)
# Prefix Cache Hit Rate. Note that we always use
# the cache hit rate of the first virtual engine.
cpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.CPU)
gpu_prefix_cache_hit_rate = self.scheduler[
0].get_prefix_cache_hit_rate(Device.GPU)
# Iteration stats # Iteration stats
num_prompt_tokens_iter = 0 num_prompt_tokens_iter = 0
num_generation_tokens_iter = 0 num_generation_tokens_iter = 0
@ -1498,6 +1505,9 @@ class LLMEngine:
# KV Cache Usage in % # KV Cache Usage in %
gpu_cache_usage_sys=gpu_cache_usage_sys, gpu_cache_usage_sys=gpu_cache_usage_sys,
cpu_cache_usage_sys=cpu_cache_usage_sys, cpu_cache_usage_sys=cpu_cache_usage_sys,
# Prefix Cache Hit Rate
cpu_prefix_cache_hit_rate=cpu_prefix_cache_hit_rate,
gpu_prefix_cache_hit_rate=gpu_prefix_cache_hit_rate,
# Iteration stats # Iteration stats
num_prompt_tokens_iter=num_prompt_tokens_iter, num_prompt_tokens_iter=num_prompt_tokens_iter,

View File

@ -71,6 +71,17 @@ class Metrics:
documentation="CPU KV-cache usage. 1 means 100 percent usage.", documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames, labelnames=labelnames,
multiprocess_mode="sum") multiprocess_mode="sum")
# Prefix caching block hit rate
self.gauge_cpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:cpu_prefix_cache_hit_rate",
documentation="CPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
self.gauge_gpu_prefix_cache_hit_rate = self._gauge_cls(
name="vllm:gpu_prefix_cache_hit_rate",
documentation="GPU prefix cache block hit rate.",
labelnames=labelnames,
multiprocess_mode="sum")
# Iteration stats # Iteration stats
self.counter_num_preemption = self._counter_cls( self.counter_num_preemption = self._counter_cls(
@ -351,7 +362,13 @@ class LoggingStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys * 100, stats.gpu_cache_usage_sys * 100,
stats.cpu_cache_usage_sys * 100, stats.cpu_cache_usage_sys * 100,
) )
if (stats.cpu_prefix_cache_hit_rate >= 0
or stats.gpu_prefix_cache_hit_rate >= 0):
logger.info(
"Prefix cache hit rate: GPU: %.2f%%, CPU: %.2f%%",
stats.gpu_prefix_cache_hit_rate * 100,
stats.cpu_prefix_cache_hit_rate * 100,
)
if self.spec_decode_metrics is not None: if self.spec_decode_metrics is not None:
logger.info( logger.info(
self._format_spec_decode_metrics_str( self._format_spec_decode_metrics_str(
@ -423,6 +440,10 @@ class PrometheusStatLogger(StatLoggerBase):
stats.gpu_cache_usage_sys) stats.gpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_cache_usage, self._log_gauge(self.metrics.gauge_cpu_cache_usage,
stats.cpu_cache_usage_sys) stats.cpu_cache_usage_sys)
self._log_gauge(self.metrics.gauge_cpu_prefix_cache_hit_rate,
stats.cpu_prefix_cache_hit_rate)
self._log_gauge(self.metrics.gauge_gpu_prefix_cache_hit_rate,
stats.gpu_prefix_cache_hit_rate)
# Iteration level data # Iteration level data
self._log_counter(self.metrics.counter_num_preemption, self._log_counter(self.metrics.counter_num_preemption,

View File

@ -32,6 +32,9 @@ class Stats:
# KV Cache Usage in % # KV Cache Usage in %
gpu_cache_usage_sys: float gpu_cache_usage_sys: float
cpu_cache_usage_sys: float cpu_cache_usage_sys: float
# Prefix caching block hit rate
cpu_prefix_cache_hit_rate: float
gpu_prefix_cache_hit_rate: float
# Iteration stats (should have _iter suffix) # Iteration stats (should have _iter suffix)
num_prompt_tokens_iter: int num_prompt_tokens_iter: int