Support SHA256 as hash function in prefix caching (#15297)

Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
marko 2025-03-26 19:11:28 +01:00 committed by GitHub
parent 35fad35a48
commit 27df5199d9
12 changed files with 214 additions and 71 deletions

View File

@ -15,12 +15,13 @@ Block 3: |<------------------ prefix -------------------->| |<--- block tokens -
In the example above, the KV cache in the first block can be uniquely identified with the tokens “A gentle breeze stirred”. The third block can be uniquely identified with the tokens in the block “laughed in the distance”, along with the prefix tokens “A gentle breeze stirred the leaves as children”. Therefore, we can build the block hash as `hash(tuple[components])`, where the components are:
* Parent hash value: The hash value of the parent block.
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).
> **Note 1:** We only cache full blocks.
> **Note 2:** The above hash key structure is not 100% collision free. Theoretically, it's still possible for different prefix tokens to have the same hash value. To avoid hash collisions **in a multi-tenant setup, we advise using SHA256** as the hash function instead of the default builtin hash.
SHA256 has been supported since vLLM v0.8.3 and must be enabled with a command line argument. It comes with a performance impact of about 100-200 ns per token (~6 ms for 50k tokens of context).
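For illustration, here is a minimal sketch of the chained hashing described above. The helper `sha256_int` mirrors the `sha256()` utility added in this change, and the token strings stand in for token IDs:

```python
import hashlib
import pickle

def sha256_int(obj) -> int:
    # Serialize with pickle, then hash the bytes (same scheme as vllm.utils.sha256).
    data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
    return int.from_bytes(hashlib.sha256(data).digest(), byteorder="big")

parent_hash = None  # the first block has no parent
for block_tokens in [("A", "gentle", "breeze", "stirred"),
                     ("the", "leaves", "as", "children"),
                     ("laughed", "in", "the", "distance")]:
    extra_keys = None  # e.g. LoRA IDs or multi-modal input hashes
    parent_hash = sha256_int((parent_hash, block_tokens, extra_keys))
```

Because each block hash folds in its parent's hash, identical block tokens at different positions in the prefix produce different hashes. To enable this scheme from the command line, pass `--prefix-caching-hash-algo sha256`.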
**A hashing example with multi-modality inputs**
In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages:

View File

@ -2,6 +2,8 @@
# ruff: noqa
import asyncio
import hashlib
import pickle
import socket
from collections.abc import AsyncIterator
from unittest.mock import patch
@ -14,7 +16,8 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
PlaceholderModule, StoreBoolean, bind_kv_cache,
deprecate_kwargs, get_open_port, memory_profiling,
merge_async_iterators, sha256, supports_kw,
swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning
@ -476,3 +479,21 @@ def test_swap_dict_values(obj, key1, key2):
assert obj[key1] == original_obj[key2]
else:
assert key1 not in obj
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
(None, bool, [1, 2, 3])])
@pytest.mark.parametrize("output", [0, 1, 2])
def test_sha256(input: tuple, output: int):
hash = sha256(input)
assert hash is not None
assert isinstance(hash, int)
assert hash != 0
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
# hashing again, returns the same value
assert hash == sha256(input)
# hashing different input, returns different value
assert hash != sha256(input + (1, ))

View File

@ -5,8 +5,12 @@ import torch
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
FreeKVCacheBlockQueue, KVCacheBlock,
PrefixCachingMetrics,
generate_block_hash_extra_keys,
hash_block_tokens,
hash_request_tokens,
@ -16,6 +20,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
# yapf: enable
def make_request(request_id,
prompt_token_ids,
@ -40,6 +46,12 @@ def make_request(request_id,
)
def test_none_hash():
assert NONE_HASH is not None
assert isinstance(NONE_HASH, int)
assert NONE_HASH != 0
def test_kv_cache_block():
# Test KVCacheBlock initialization
block = KVCacheBlock(block_id=0)
@ -190,21 +202,23 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
assert next_mm_idx == 0
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_block_tokens(hash_fn):
parent_block_hash = 123
curr_block_token_ids = (1, 2, 3)
extra_keys = ("key1", "key2")
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
curr_block_token_ids, extra_keys)
assert isinstance(block_hash, BlockHashType)
assert block_hash.hash_value == hash_fn(
(parent_block_hash, curr_block_token_ids, extra_keys))
assert block_hash.token_ids == curr_block_token_ids
assert block_hash.extra_keys == extra_keys
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_request_tokens(hash_fn):
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
@ -219,7 +233,7 @@ def test_hash_request_tokens():
)
block_size = 3
block_hashes = hash_request_tokens(hash_fn, block_size, request)
assert len(block_hashes) == 2
assert isinstance(block_hashes[0], BlockHashType)
@ -234,7 +248,8 @@ def test_hash_request_tokens():
assert block_hashes[1].extra_keys == ("hash2", )
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_tokens_different_mm_input(hash_fn):
request1 = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
@ -260,13 +275,14 @@ def test_hash_tokens_different_mm_input():
mm_hashes=["hash3", "hash2"],
)
block_size = 3
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
assert block_hashes1[0] != block_hashes2[0]
assert block_hashes1[1] != block_hashes2[1]
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_hash_request_tokens_no_mm_inputs(hash_fn):
request = make_request(
request_id=0,
prompt_token_ids=[_ for _ in range(6)],
@ -275,7 +291,7 @@ def test_hash_request_tokens_no_mm_inputs():
)
block_size = 3
block_hashes = hash_request_tokens(hash_fn, block_size, request)
assert len(block_hashes) == 2
assert block_hashes[0].token_ids == (0, 1, 2)

View File

@ -7,7 +7,7 @@ import pytest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
@ -39,16 +39,21 @@ def make_request(request_id,
)
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
def test_prefill(hash_algo):
manager = KVCacheManager(
block_size=16,
num_gpu_blocks=10,
max_model_len=8192,
sliding_window=None,
enable_caching=True,
caching_hash_algo=hash_algo,
num_preallocate_tokens=16,
)
# choose the hash function according to the parameter
hash_fn = sha256 if hash_algo == "sha256" else hash
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
@ -68,7 +73,8 @@ def test_prefill():
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
@ -163,6 +169,8 @@ def test_prefill_plp():
enable_caching=True,
num_preallocate_tokens=16,
)
# the default hash function is hash
hash_fn = hash
# Complete 3 blocks (48 tokens)
common_token_ids = [i for i in range(3) for _ in range(16)]
@ -185,7 +193,8 @@ def test_prefill_plp():
parent_block_hash = None
for block_id in (0, 1, 2):
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
block_tokens)
assert manager.block_pool.blocks[block_id].block_hash == block_hash
assert manager.block_pool.blocks[block_id].ref_cnt == 1
parent_block_hash = block_hash.hash_value
@ -522,7 +531,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
assert len(blocks) == 1 + num_preallocated_blocks
@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_cache_blocks(hash_fn):
"""
This is a unit test that tests the correctness of the _cache_full_blocks
function of KVCacheManager.
@ -550,6 +560,7 @@ def test_cache_blocks():
num_cached_blocks=0,
num_full_blocks=2,
block_size=block_size,
hash_fn=hash_fn,
)
assert len(block_pool.cached_block_hash_to_block) == 2
@ -564,6 +575,7 @@ def test_cache_blocks():
num_cached_blocks=2,
num_full_blocks=3,
block_size=block_size,
hash_fn=hash_fn,
)
assert len(block_pool.cached_block_hash_to_block) == 3
assert blocks[0].block_hash is not None

View File

@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from argparse import ArgumentError
import pytest
from vllm import envs
@ -32,6 +34,24 @@ def test_prefix_caching_from_cli():
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.enable_prefix_caching
# default hash algorithm is "builtin"
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
# set hash algorithm to sha256
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
# set hash algorithm to builtin
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
# an invalid hash algorithm raises an error
parser.exit_on_error = False
with pytest.raises(ArgumentError):
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
def test_defaults_with_usage_context():
engine_args = EngineArgs(model="facebook/opt-125m")

View File

@ -1124,6 +1124,7 @@ class CacheConfig:
num_gpu_blocks_override: Optional[int] = None,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
prefix_caching_hash_algo: str = "builtin",
cpu_offload_gb: float = 0,
calculate_kv_scales: Optional[bool] = None,
) -> None:
@ -1135,6 +1136,7 @@ class CacheConfig:
self.is_attention_free = is_attention_free
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.prefix_caching_hash_algo = prefix_caching_hash_algo
self.cpu_offload_gb = cpu_offload_gb
self.calculate_kv_scales = calculate_kv_scales
self._verify_args()
@ -1185,6 +1187,13 @@ class CacheConfig:
"Prefix caching is not supported with sliding window. "
"Run with --disable-sliding-window to use prefix caching.")
if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
"builtin", "sha256"):
raise ValueError(
"Unknown prefix caching hash algorithm: "
f"{self.prefix_caching_hash_algo}. Must be either "
"'builtin' or 'sha256'.")
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",

View File

@ -118,6 +118,7 @@ class EngineArgs:
max_parallel_loading_workers: Optional[int] = None
block_size: Optional[int] = None
enable_prefix_caching: Optional[bool] = None
prefix_caching_hash_algo: str = "builtin"
disable_sliding_window: bool = False
disable_cascade_attn: bool = False
use_v2_block_manager: bool = True
@ -475,6 +476,16 @@ class EngineArgs:
help="Enables automatic prefix caching. "
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
)
parser.add_argument(
"--prefix-caching-hash-algo",
type=str,
choices=["builtin", "sha256"],
default=EngineArgs.prefix_caching_hash_algo,
help="Set the hash algorithm for prefix caching. "
"Options are 'builtin' (Python's built-in hash) or 'sha256' "
"(collision resistant but with certain overheads). Defaults "
"to 'builtin'.",
)
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
@ -1329,6 +1340,7 @@ class EngineArgs:
num_gpu_blocks_override=self.num_gpu_blocks_override,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
)
@ -1737,12 +1749,22 @@ class EngineArgs:
msg = "Chunked prefill is not supported for pooling models"
raise ValueError(msg)
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching:
# Disable prefix caching for multimodal models for VLLM_V0.
if model_config.is_multimodal_model:
logger.warning(
"--enable-prefix-caching is not supported for multimodal "
"models in V0 and has been disabled.")
self.enable_prefix_caching = False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
elif self.prefix_caching_hash_algo == "sha256":
raise ValueError(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'.")
# Set max_num_seqs to 256 for VLLM_V0.
if self.max_num_seqs is None:
@ -1758,6 +1780,10 @@ class EngineArgs:
if self.enable_prefix_caching is None:
self.enable_prefix_caching = True
# if using prefix caching, we must set a hash algo
if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
self.prefix_caching_hash_algo = "builtin"
# V1 should use the new scheduler by default.
# Swap it only if this arg is set to the original V0 default
if self.scheduler_cls == EngineArgs.scheduler_cls:

View File

@ -10,6 +10,7 @@ import datetime
import enum
import gc
import getpass
import hashlib
import importlib
import importlib.metadata
import importlib.util
@ -17,6 +18,7 @@ import inspect
import ipaddress
import multiprocessing
import os
import pickle
import re
import signal
import socket
@ -2442,3 +2444,21 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
return wrapper
return decorator
def sha256(input) -> int:
"""Hash any picklable Python object using SHA-256.
The input is serialized using pickle before hashing, which allows
arbitrary Python objects to be used. Note that this function does
not use a hash seed; if you need one, prepend it explicitly to the input.
Args:
input: Any picklable Python object.
Returns:
An integer representing the SHA-256 hash of the serialized input.
"""
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
byteorder="big")

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from collections import defaultdict
from collections.abc import Iterable
from typing import Optional
from typing import Callable, Optional
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
@ -15,10 +15,10 @@ logger = init_logger(__name__)
class BlockPool:
"""BlockPool that manages KVCacheBlocks.
It provides methods to allocate, free and cache the kv cache blocks. The
free_block_queue stores the free blocks in eviction order to enable
allocation, free, and cache eviction. The cached_block_hash_to_block
maps between block hash and cached block to support finding cached blocks
by their block hash.
Args:
@ -75,11 +75,12 @@ class BlockPool:
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it computes the
block hashes for the blocks starting from `num_cached_blocks` to
`num_full_blocks`, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
@ -87,12 +88,13 @@ class BlockPool:
request: The request to cache the blocks.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
return
@ -138,7 +140,7 @@ class BlockPool:
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)

View File

@ -5,7 +5,7 @@ from collections.abc import Iterable
from typing import Optional
from vllm.logger import init_logger
from vllm.utils import cdiv, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
hash_request_tokens)
@ -24,6 +24,7 @@ class KVCacheManager:
max_model_len: int,
sliding_window: Optional[int] = None,
enable_caching: bool = True,
caching_hash_algo: str = "builtin",
num_preallocate_tokens: int = 64,
log_stats: bool = False,
) -> None:
@ -33,6 +34,7 @@ class KVCacheManager:
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
self.sliding_window = sliding_window
self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
# FIXME: make prefix cache stats conditional on log_stats
self.log_stats = log_stats
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
@ -109,7 +111,8 @@ class KVCacheManager:
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
self.prefix_cache_stats.requests += 1
@ -247,6 +250,7 @@ class KVCacheManager:
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks_after_append,
block_size=self.block_size,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[

View File

@ -1,12 +1,14 @@
# SPDX-License-Identifier: Apache-2.0
"""KV-Cache Utilities."""
import os
from collections import deque
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Optional
from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
KVCacheSpec, KVCacheTensor)
from vllm.v1.metrics.stats import PrefixCacheStats
@ -18,9 +20,8 @@ logger = init_logger(__name__)
class BlockHashType(NamedTuple):
"""Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. By using SHA256, however,
hash collisions become practically impossible.
"""
# Hash value of the block in an integer.
hash_value: int
@ -30,6 +31,20 @@ class BlockHashType(NamedTuple):
extra_keys: Optional[Any] = None
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical, as
# it is done once per process.
#
# We use a random value to avoid hash collisions, or the PYTHONHASHSEED
# environment variable if it is set, so that processes can share the seed
# if needed. This aligns with the behavior of Python's hash() function,
# which also uses a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
'PYTHONHASHSEED') is None else sha256(os.getenv('PYTHONHASHSEED'))
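A minimal sketch of the seeding behavior above, assuming vLLM is importable in a fresh interpreter (the snippet and its names are illustrative, not part of this change): with PYTHONHASHSEED set, NONE_HASH is derived deterministically via sha256, so separate processes that share the seed agree on the root of every block hash chain.

```python
import os
import subprocess
import sys

code = "from vllm.v1.core.kv_cache_utils import NONE_HASH; print(NONE_HASH)"
env = {**os.environ, "PYTHONHASHSEED": "42"}

def none_hash_of_fresh_process() -> str:
    # Launch a fresh interpreter and read back its NONE_HASH value.
    return subprocess.run([sys.executable, "-c", code], env=env,
                          capture_output=True, text=True).stdout

assert none_hash_of_fresh_process() == none_hash_of_fresh_process()
```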
class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the most recent N requests.
@ -148,7 +163,7 @@ class FreeKVCacheBlockQueue:
builtin deque to support removing a block in the middle of the queue
in O(1) time. To close the performance gap to the builtin deque which is
implemented in C++, this class does not allocate any Python objects when
manipulating the linked list. Instead, this class manipulates the
prev_free_block and next_free_block attributes of the given blocks.
The queue is ordered by block ID in the beginning. When a block is allocated
@ -178,7 +193,7 @@ class FreeKVCacheBlockQueue:
def popleft(self) -> KVCacheBlock:
"""Pop the first free block and reduce num_free_blocks by 1.
Returns:
The first free block.
"""
@ -191,7 +206,7 @@ class FreeKVCacheBlockQueue:
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
Args:
block: The block to remove.
"""
@ -235,7 +250,7 @@ class FreeKVCacheBlockQueue:
def get_all_free_blocks(self) -> list[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.
Returns:
A list of free blocks.
"""
@ -251,10 +266,10 @@ def need_extra_keys(request: Request) -> bool:
"""Check whether the blocks allocated to this request need extra hash keys.
Args:
request (Request): The request.
Returns:
bool: Whether blocks allocated to this request need extra hash keys.
"""
# Multimodal requests need to include the MM hash.
@ -269,13 +284,13 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
computation. For multi-modal inputs, the extra keys are
(mm_hash, start_offset) that indicate a mm input contained in the
block and its starting offset in the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
@ -333,10 +348,10 @@ def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
"""Generate extra keys related to LoRA for block hash computation.
Args:
request: The request object.
Returns:
Return LoRA id of the request if it is a LoRA request. Return empty
list otherwise.
@ -351,13 +366,13 @@ def generate_block_hash_extra_keys(
start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
@ -375,6 +390,7 @@ def generate_block_hash_extra_keys(
def hash_block_tokens(
hash_function: Callable,
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
@ -395,21 +411,16 @@ def hash_block_tokens(
The entire tuple is used as the hash key of the block.
"""
if not parent_block_hash:
parent_block_hash = NONE_HASH
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType(
hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys)
def hash_request_tokens(hash_function: Any, block_size: int,
request: Request) -> list[BlockHashType]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
@ -441,7 +452,7 @@ def hash_request_tokens(block_size: int,
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
@ -452,7 +463,7 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
least one request with the model's max_model_len.
Args:
@ -489,15 +500,15 @@ def create_kv_cache_group_specs(
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
"""
Create KVCacheGroupSpec object for each kv cache group layer.
The layers in the same group should share the same
KVCacheSpec.
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
@ -614,11 +625,11 @@ def get_kv_cache_config(vllm_config: VllmConfig,
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
"""
Make the KV cache configurations for each worker consistent, so that all
workers can be controlled by the same KVCacheManager.
This function verifies that the layer group of each worker are the same,
and changes the num_blocks of each worker to the smallest among all workers.
Args:
kv_cache_configs: The KV cache configurations for each worker. Will be
in-place modified to make them consistent.

View File

@ -61,6 +61,7 @@ class Scheduler(SchedulerInterface):
max_model_len=self.max_model_len,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching,
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
log_stats=self.log_stats)
self.block_size = self.cache_config.block_size