Support SHA256 as hash function in prefix caching (#15297)
Signed-off-by: Marko Rosenmueller <5467316+dr75@users.noreply.github.com>
This commit is contained in:
parent
35fad35a48
commit
27df5199d9
@ -18,9 +18,10 @@ In the example above, the KV cache in the first block can be uniquely identified
|
||||
* Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision.
|
||||
* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below).
|
||||
|
||||
Note 1: We only cache full blocks.
|
||||
> **Note 1:** We only cache full blocks.
|
||||
|
||||
Note 2: The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value, but this should be nearly impossible to happen. Of course, contributions are welcome if you have an awesome idea to eliminate collusion entirely.
|
||||
> **Note 2:** The above hash key structure is not 100% collision free. Theoretically it’s still possible for the different prefix tokens to have the same hash value. To avoid any hash collisions **in a multi-tenant setup, we advise to use SHA256** as hash function instead of the default builtin hash.
|
||||
SHA256 is supported since vLLM v0.8.3 and must be enabled with a command line argument. It comes with a performance impact of about 100-200ns per token (~6ms for 50k tokens of context).
|
||||
|
||||
**A hashing example with multi-modality inputs**
|
||||
In this example, we illustrate how prefix caching works with multi-modality inputs (e.g., images). Assuming we have a request with the following messages:
|
||||
|
@ -2,6 +2,8 @@
|
||||
# ruff: noqa
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pickle
|
||||
import socket
|
||||
from collections.abc import AsyncIterator
|
||||
from unittest.mock import patch
|
||||
@ -14,7 +16,8 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
|
||||
PlaceholderModule, StoreBoolean, bind_kv_cache,
|
||||
deprecate_kwargs, get_open_port, memory_profiling,
|
||||
merge_async_iterators, supports_kw, swap_dict_values)
|
||||
merge_async_iterators, sha256, supports_kw,
|
||||
swap_dict_values)
|
||||
|
||||
from .utils import create_new_process_for_each_test, error_on_warning
|
||||
|
||||
@ -476,3 +479,21 @@ def test_swap_dict_values(obj, key1, key2):
|
||||
assert obj[key1] == original_obj[key2]
|
||||
else:
|
||||
assert key1 not in obj
|
||||
|
||||
@pytest.mark.parametrize("input", [(), ("abc", ), (None, ),
|
||||
(None, bool, [1, 2, 3])])
|
||||
@pytest.mark.parametrize("output", [0, 1, 2])
|
||||
def test_sha256(input: tuple, output: int):
|
||||
hash = sha256(input)
|
||||
assert hash is not None
|
||||
assert isinstance(hash, int)
|
||||
assert hash != 0
|
||||
|
||||
bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
assert hash == int.from_bytes(hashlib.sha256(bytes).digest(), byteorder="big")
|
||||
|
||||
# hashing again, returns the same value
|
||||
assert hash == sha256(input)
|
||||
|
||||
# hashing different input, returns different value
|
||||
assert hash != sha256(input + (1, ))
|
||||
|
@ -5,8 +5,12 @@ import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
|
||||
KVCacheBlock, PrefixCachingMetrics,
|
||||
from vllm.utils import sha256
|
||||
# disable yapf here as it formats differently than isort such that both fail
|
||||
# yapf: disable
|
||||
from vllm.v1.core.kv_cache_utils import (NONE_HASH, BlockHashType,
|
||||
FreeKVCacheBlockQueue, KVCacheBlock,
|
||||
PrefixCachingMetrics,
|
||||
generate_block_hash_extra_keys,
|
||||
hash_block_tokens,
|
||||
hash_request_tokens,
|
||||
@ -16,6 +20,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
# yapf: enable
|
||||
|
||||
|
||||
def make_request(request_id,
|
||||
prompt_token_ids,
|
||||
@ -40,6 +46,12 @@ def make_request(request_id,
|
||||
)
|
||||
|
||||
|
||||
def test_none_hash():
|
||||
assert NONE_HASH is not None
|
||||
assert isinstance(NONE_HASH, int)
|
||||
assert NONE_HASH != 0
|
||||
|
||||
|
||||
def test_kv_cache_block():
|
||||
# Test KVCacheBlock initialization
|
||||
block = KVCacheBlock(block_id=0)
|
||||
@ -190,21 +202,23 @@ def test_generate_block_hash_extra_keys_no_mm_inputs():
|
||||
assert next_mm_idx == 0
|
||||
|
||||
|
||||
def test_hash_block_tokens():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_block_tokens(hash_fn):
|
||||
parent_block_hash = 123
|
||||
curr_block_token_ids = (1, 2, 3)
|
||||
extra_keys = ("key1", "key2")
|
||||
|
||||
block_hash = hash_block_tokens(parent_block_hash, curr_block_token_ids,
|
||||
extra_keys)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
curr_block_token_ids, extra_keys)
|
||||
assert isinstance(block_hash, BlockHashType)
|
||||
assert block_hash.hash_value == hash(
|
||||
assert block_hash.hash_value == hash_fn(
|
||||
(parent_block_hash, curr_block_token_ids, extra_keys))
|
||||
assert block_hash.token_ids == curr_block_token_ids
|
||||
assert block_hash.extra_keys == extra_keys
|
||||
|
||||
|
||||
def test_hash_request_tokens():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_request_tokens(hash_fn):
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -219,7 +233,7 @@ def test_hash_request_tokens():
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(block_size, request)
|
||||
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert isinstance(block_hashes[0], BlockHashType)
|
||||
@ -234,7 +248,8 @@ def test_hash_request_tokens():
|
||||
assert block_hashes[1].extra_keys == ("hash2", )
|
||||
|
||||
|
||||
def test_hash_tokens_different_mm_input():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_tokens_different_mm_input(hash_fn):
|
||||
request1 = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -260,13 +275,14 @@ def test_hash_tokens_different_mm_input():
|
||||
mm_hashes=["hash3", "hash2"],
|
||||
)
|
||||
block_size = 3
|
||||
block_hashes1 = hash_request_tokens(block_size, request1)
|
||||
block_hashes2 = hash_request_tokens(block_size, request2)
|
||||
block_hashes1 = hash_request_tokens(hash_fn, block_size, request1)
|
||||
block_hashes2 = hash_request_tokens(hash_fn, block_size, request2)
|
||||
assert block_hashes1[0] != block_hashes2[0]
|
||||
assert block_hashes1[1] != block_hashes2[1]
|
||||
|
||||
|
||||
def test_hash_request_tokens_no_mm_inputs():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_hash_request_tokens_no_mm_inputs(hash_fn):
|
||||
request = make_request(
|
||||
request_id=0,
|
||||
prompt_token_ids=[_ for _ in range(6)],
|
||||
@ -275,7 +291,7 @@ def test_hash_request_tokens_no_mm_inputs():
|
||||
)
|
||||
|
||||
block_size = 3
|
||||
block_hashes = hash_request_tokens(block_size, request)
|
||||
block_hashes = hash_request_tokens(hash_fn, block_size, request)
|
||||
|
||||
assert len(block_hashes) == 2
|
||||
assert block_hashes[0].token_ids == (0, 1, 2)
|
||||
|
@ -7,7 +7,7 @@ import pytest
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils import cdiv, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
@ -39,16 +39,21 @@ def make_request(request_id,
|
||||
)
|
||||
|
||||
|
||||
def test_prefill():
|
||||
@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
|
||||
def test_prefill(hash_algo):
|
||||
manager = KVCacheManager(
|
||||
block_size=16,
|
||||
num_gpu_blocks=10,
|
||||
max_model_len=8192,
|
||||
sliding_window=None,
|
||||
enable_caching=True,
|
||||
caching_hash_algo=hash_algo,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
|
||||
# choose the hash function according to the parameter
|
||||
hash_fn = sha256 if hash_algo == "sha256" else hash
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
|
||||
@ -68,7 +73,8 @@ def test_prefill():
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
@ -163,6 +169,8 @@ def test_prefill_plp():
|
||||
enable_caching=True,
|
||||
num_preallocate_tokens=16,
|
||||
)
|
||||
# the default hash function is hash
|
||||
hash_fn = hash
|
||||
|
||||
# Complete 3 blocks (48 tokens)
|
||||
common_token_ids = [i for i in range(3) for _ in range(16)]
|
||||
@ -185,7 +193,8 @@ def test_prefill_plp():
|
||||
parent_block_hash = None
|
||||
for block_id in (0, 1, 2):
|
||||
block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16])
|
||||
block_hash = hash_block_tokens(parent_block_hash, block_tokens)
|
||||
block_hash = hash_block_tokens(hash_fn, parent_block_hash,
|
||||
block_tokens)
|
||||
assert manager.block_pool.blocks[block_id].block_hash == block_hash
|
||||
assert manager.block_pool.blocks[block_id].ref_cnt == 1
|
||||
parent_block_hash = block_hash.hash_value
|
||||
@ -522,7 +531,8 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
|
||||
assert len(blocks) == 1 + num_preallocated_blocks
|
||||
|
||||
|
||||
def test_cache_blocks():
|
||||
@pytest.mark.parametrize("hash_fn", [sha256, hash])
|
||||
def test_cache_blocks(hash_fn):
|
||||
"""
|
||||
This is a unit test that tests the correctness of the _cache_full_blocks
|
||||
function of KVCacheManager.
|
||||
@ -550,6 +560,7 @@ def test_cache_blocks():
|
||||
num_cached_blocks=0,
|
||||
num_full_blocks=2,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
)
|
||||
|
||||
assert len(block_pool.cached_block_hash_to_block) == 2
|
||||
@ -564,6 +575,7 @@ def test_cache_blocks():
|
||||
num_cached_blocks=2,
|
||||
num_full_blocks=3,
|
||||
block_size=block_size,
|
||||
hash_fn=hash_fn,
|
||||
)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 3
|
||||
assert blocks[0].block_hash is not None
|
||||
|
@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from argparse import ArgumentError
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import envs
|
||||
@ -32,6 +34,24 @@ def test_prefix_caching_from_cli():
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.enable_prefix_caching
|
||||
|
||||
# default hash algorithm is "builtin"
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
|
||||
# set hash algorithm to sha256
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "sha256"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "sha256"
|
||||
|
||||
# set hash algorithm to builtin
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "builtin"])
|
||||
vllm_config = EngineArgs.from_cli_args(args=args).create_engine_config()
|
||||
assert vllm_config.cache_config.prefix_caching_hash_algo == "builtin"
|
||||
|
||||
# an invalid hash algorithm raises an error
|
||||
parser.exit_on_error = False
|
||||
with pytest.raises(ArgumentError):
|
||||
args = parser.parse_args(["--prefix-caching-hash-algo", "invalid"])
|
||||
|
||||
|
||||
def test_defaults_with_usage_context():
|
||||
engine_args = EngineArgs(model="facebook/opt-125m")
|
||||
|
@ -1124,6 +1124,7 @@ class CacheConfig:
|
||||
num_gpu_blocks_override: Optional[int] = None,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_prefix_caching: bool = False,
|
||||
prefix_caching_hash_algo: str = "builtin",
|
||||
cpu_offload_gb: float = 0,
|
||||
calculate_kv_scales: Optional[bool] = None,
|
||||
) -> None:
|
||||
@ -1135,6 +1136,7 @@ class CacheConfig:
|
||||
self.is_attention_free = is_attention_free
|
||||
self.sliding_window = sliding_window
|
||||
self.enable_prefix_caching = enable_prefix_caching
|
||||
self.prefix_caching_hash_algo = prefix_caching_hash_algo
|
||||
self.cpu_offload_gb = cpu_offload_gb
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
self._verify_args()
|
||||
@ -1185,6 +1187,13 @@ class CacheConfig:
|
||||
"Prefix caching is not supported with sliding window. "
|
||||
"Run with --disable-sliding-window to use prefix caching.")
|
||||
|
||||
if self.enable_prefix_caching and self.prefix_caching_hash_algo not in (
|
||||
"builtin", "sha256"):
|
||||
raise ValueError(
|
||||
"Unknown prefix caching hash algorithm: "
|
||||
f"{self.prefix_caching_hash_algo}. Must be either "
|
||||
"'builtin' or 'sha256'.")
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
|
@ -118,6 +118,7 @@ class EngineArgs:
|
||||
max_parallel_loading_workers: Optional[int] = None
|
||||
block_size: Optional[int] = None
|
||||
enable_prefix_caching: Optional[bool] = None
|
||||
prefix_caching_hash_algo: str = "builtin"
|
||||
disable_sliding_window: bool = False
|
||||
disable_cascade_attn: bool = False
|
||||
use_v2_block_manager: bool = True
|
||||
@ -475,6 +476,16 @@ class EngineArgs:
|
||||
help="Enables automatic prefix caching. "
|
||||
"Use ``--no-enable-prefix-caching`` to disable explicitly.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefix-caching-hash-algo",
|
||||
type=str,
|
||||
choices=["builtin", "sha256"],
|
||||
default=EngineArgs.prefix_caching_hash_algo,
|
||||
help="Set the hash algorithm for prefix caching. "
|
||||
"Options are 'builtin' (Python's built-in hash) or 'sha256' "
|
||||
"(collision resistant but with certain overheads). Defaults "
|
||||
"to 'builtin'.",
|
||||
)
|
||||
parser.add_argument('--disable-sliding-window',
|
||||
action='store_true',
|
||||
help='Disables sliding window, '
|
||||
@ -1329,6 +1340,7 @@ class EngineArgs:
|
||||
num_gpu_blocks_override=self.num_gpu_blocks_override,
|
||||
sliding_window=model_config.get_sliding_window(),
|
||||
enable_prefix_caching=self.enable_prefix_caching,
|
||||
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
|
||||
cpu_offload_gb=self.cpu_offload_gb,
|
||||
calculate_kv_scales=self.calculate_kv_scales,
|
||||
)
|
||||
@ -1737,12 +1749,22 @@ class EngineArgs:
|
||||
msg = "Chunked prefill is not supported for pooling models"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Disable prefix caching for multimodal models for VLLM_V0.
|
||||
if (model_config.is_multimodal_model and self.enable_prefix_caching):
|
||||
logger.warning(
|
||||
"--enable-prefix-caching is not supported for multimodal "
|
||||
"models in V0 and has been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
# if using prefix caching, we must set a hash algo
|
||||
if self.enable_prefix_caching:
|
||||
# Disable prefix caching for multimodal models for VLLM_V0.
|
||||
if model_config.is_multimodal_model:
|
||||
logger.warning(
|
||||
"--enable-prefix-caching is not supported for multimodal "
|
||||
"models in V0 and has been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
# VLLM_V0 only supports builtin hash algo for prefix caching.
|
||||
if self.prefix_caching_hash_algo is None:
|
||||
self.prefix_caching_hash_algo = "builtin"
|
||||
elif self.prefix_caching_hash_algo == "sha256":
|
||||
raise ValueError(
|
||||
"sha256 is not supported for prefix caching in V0 engine. "
|
||||
"Please use 'builtin'.")
|
||||
|
||||
# Set max_num_seqs to 256 for VLLM_V0.
|
||||
if self.max_num_seqs is None:
|
||||
@ -1758,6 +1780,10 @@ class EngineArgs:
|
||||
if self.enable_prefix_caching is None:
|
||||
self.enable_prefix_caching = True
|
||||
|
||||
# if using prefix caching, we must set a hash algo
|
||||
if self.enable_prefix_caching and self.prefix_caching_hash_algo is None:
|
||||
self.prefix_caching_hash_algo = "builtin"
|
||||
|
||||
# V1 should use the new scheduler by default.
|
||||
# Swap it only if this arg is set to the original V0 default
|
||||
if self.scheduler_cls == EngineArgs.scheduler_cls:
|
||||
|
@ -10,6 +10,7 @@ import datetime
|
||||
import enum
|
||||
import gc
|
||||
import getpass
|
||||
import hashlib
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
@ -17,6 +18,7 @@ import inspect
|
||||
import ipaddress
|
||||
import multiprocessing
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import signal
|
||||
import socket
|
||||
@ -2442,3 +2444,21 @@ def cprofile(save_file: Optional[str] = None, enabled: bool = True):
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def sha256(input) -> int:
|
||||
"""Hash any picklable Python object using SHA-256.
|
||||
|
||||
The input is serialized using pickle before hashing, which allows
|
||||
arbitrary Python objects to be used. Note that this function does
|
||||
not use a hash seed—if you need one, prepend it explicitly to the input.
|
||||
|
||||
Args:
|
||||
input: Any picklable Python object.
|
||||
|
||||
Returns:
|
||||
An integer representing the SHA-256 hash of the serialized input.
|
||||
"""
|
||||
input_bytes = pickle.dumps(input, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return int.from_bytes(hashlib.sha256(input_bytes).digest(),
|
||||
byteorder="big")
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
|
||||
@ -75,6 +75,7 @@ class BlockPool:
|
||||
num_cached_blocks: int,
|
||||
num_full_blocks: int,
|
||||
block_size: int,
|
||||
hash_fn: Callable,
|
||||
) -> None:
|
||||
"""Cache a list of full blocks for prefix caching.
|
||||
This function takes a list of blocks that will have their block hash
|
||||
@ -93,6 +94,7 @@ class BlockPool:
|
||||
num_full_blocks: The number of blocks that are full and should
|
||||
be cached after this function.
|
||||
block_size: Number of tokens in each block.
|
||||
hash_fn: The hash function to use for block hashes.
|
||||
"""
|
||||
if num_cached_blocks == num_full_blocks:
|
||||
return
|
||||
@ -138,7 +140,7 @@ class BlockPool:
|
||||
request, start_token_idx, end_token_idx, -1)
|
||||
|
||||
# Compute the hash of the current block.
|
||||
block_hash = hash_block_tokens(prev_block_hash_value,
|
||||
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
|
||||
block_tokens, extra_keys)
|
||||
block_hashes.append(block_hash)
|
||||
|
||||
|
@ -5,7 +5,7 @@ from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import cdiv
|
||||
from vllm.utils import cdiv, sha256
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
|
||||
hash_request_tokens)
|
||||
@ -24,6 +24,7 @@ class KVCacheManager:
|
||||
max_model_len: int,
|
||||
sliding_window: Optional[int] = None,
|
||||
enable_caching: bool = True,
|
||||
caching_hash_algo: str = "builtin",
|
||||
num_preallocate_tokens: int = 64,
|
||||
log_stats: bool = False,
|
||||
) -> None:
|
||||
@ -33,6 +34,7 @@ class KVCacheManager:
|
||||
self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
|
||||
self.sliding_window = sliding_window
|
||||
self.enable_caching = enable_caching
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
self.log_stats = log_stats
|
||||
# NOTE(woosuk): To avoid frequent block allocation, we preallocate some
|
||||
@ -109,7 +111,8 @@ class KVCacheManager:
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||
if not block_hashes:
|
||||
block_hashes = hash_request_tokens(self.block_size, request)
|
||||
block_hashes = hash_request_tokens(self.caching_hash_fn,
|
||||
self.block_size, request)
|
||||
self.req_to_block_hashes[request.request_id] = block_hashes
|
||||
|
||||
self.prefix_cache_stats.requests += 1
|
||||
@ -247,6 +250,7 @@ class KVCacheManager:
|
||||
num_cached_blocks=num_cached_blocks,
|
||||
num_full_blocks=num_full_blocks_after_append,
|
||||
block_size=self.block_size,
|
||||
hash_fn=self.caching_hash_fn,
|
||||
)
|
||||
|
||||
self.num_cached_block[
|
||||
|
@ -1,12 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""KV-Cache Utilities."""
|
||||
import os
|
||||
from collections import deque
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NamedTuple, Optional
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheGroupSpec,
|
||||
KVCacheSpec, KVCacheTensor)
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
@ -18,9 +20,8 @@ logger = init_logger(__name__)
|
||||
class BlockHashType(NamedTuple):
|
||||
"""Hash value of a block (int), the token IDs in the block, and extra keys.
|
||||
We keep a tuple of token IDs and extra keys to reduce the likelihood of
|
||||
hash collisions when the hash value is the same. But please note that
|
||||
hash collisions can still theoretically occur, albeit with an extremely
|
||||
low probability.
|
||||
hash collisions when the hash value is the same. By using SHA256 however,
|
||||
hash collisions are practically impossible.
|
||||
"""
|
||||
# Hash value of the block in an integer.
|
||||
hash_value: int
|
||||
@ -30,6 +31,20 @@ class BlockHashType(NamedTuple):
|
||||
extra_keys: Optional[Any] = None
|
||||
|
||||
|
||||
# The hash seed for the first block of the prefix block sequence.
|
||||
#
|
||||
# Even if the hash function is the builtin hash(), we use sha256 to generate
|
||||
# the initial hash to simplify the code. This is not performance critical
|
||||
# as it is done one per process.
|
||||
#
|
||||
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
|
||||
# variable if set such that processes can share the seed if needed.
|
||||
# This aligns with the behavior of Python's hash() function, which also uses
|
||||
# a random seed if PYTHONHASHSEED is not set.
|
||||
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
|
||||
'PYTHONHASHSEED') is not None else sha256(os.getenv('PYTHONHASHSEED'))
|
||||
|
||||
|
||||
class PrefixCachingMetrics:
|
||||
"""Metrics for prefix caching with a hit rate of the most recent N requests.
|
||||
|
||||
@ -375,6 +390,7 @@ def generate_block_hash_extra_keys(
|
||||
|
||||
|
||||
def hash_block_tokens(
|
||||
hash_function: Callable,
|
||||
parent_block_hash: Optional[int],
|
||||
curr_block_token_ids: Sequence[int],
|
||||
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHashType:
|
||||
@ -395,21 +411,16 @@ def hash_block_tokens(
|
||||
The entire tuple is used as the hash key of the block.
|
||||
"""
|
||||
if not parent_block_hash:
|
||||
# Note that we use 'None' as a string here instead of None because
|
||||
# as of Python 3.12, hash(None) returns a constant predictable value.
|
||||
# This could possibly make it easier to find and exploit hash
|
||||
# collisions. 'None' as a string will be hashed differently per process,
|
||||
# but consistently within the same process. This is the same as the
|
||||
# behavior of None prior to Python 3.12.
|
||||
parent_block_hash = hash('None')
|
||||
parent_block_hash = NONE_HASH
|
||||
|
||||
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
|
||||
return BlockHashType(
|
||||
hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
|
||||
hash_function(
|
||||
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
|
||||
curr_block_token_ids_tuple, extra_keys)
|
||||
|
||||
|
||||
def hash_request_tokens(block_size: int,
|
||||
def hash_request_tokens(hash_function: Any, block_size: int,
|
||||
request: Request) -> list[BlockHashType]:
|
||||
"""Computes hash values of a chain of blocks given a sequence of
|
||||
token IDs. The hash value is used for prefix caching.
|
||||
@ -441,7 +452,7 @@ def hash_request_tokens(block_size: int,
|
||||
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
|
||||
request, start, end, curr_mm_idx)
|
||||
|
||||
block_hash = hash_block_tokens(parent_block_hash_value,
|
||||
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
|
||||
block_token_ids, req_extra_keys)
|
||||
ret.append(block_hash)
|
||||
parent_block_hash_value = block_hash.hash_value
|
||||
|
@ -61,6 +61,7 @@ class Scheduler(SchedulerInterface):
|
||||
max_model_len=self.max_model_len,
|
||||
sliding_window=self.cache_config.sliding_window,
|
||||
enable_caching=self.cache_config.enable_prefix_caching,
|
||||
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
|
||||
log_stats=self.log_stats)
|
||||
self.block_size = self.cache_config.block_size
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user