[Core] Improve hash collision avoidance in prefix caching (#12621)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-02-03 19:28:20 -05:00 committed by GitHub
parent 5095e96606
commit 73b35cca7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 10 deletions

View File

@ -65,8 +65,8 @@ class TestPrefixCachingBlock:
previous_block = MagicMock(spec=PrefixCachingBlock) previous_block = MagicMock(spec=PrefixCachingBlock)
prev_block_hash = random.randint(0, 1000) prev_block_hash = random.randint(0, 1000)
previous_block.content_hash = (prev_block_hash previous_block.content_hash = (prev_block_hash if prev_block_has_hash
if prev_block_has_hash else None) else hash('None'))
num_to_fill = block_size if is_curr_block_full else random.randint( num_to_fill = block_size if is_curr_block_full else random.randint(
0, block_size - 1) 0, block_size - 1)

View File

@ -65,6 +65,15 @@ class PrefixCachingBlockAllocator(BlockAllocator):
from 0 to num_blocks - 1. from 0 to num_blocks - 1.
""" """
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
# Implements Block.Factory.
def __init__( def __init__(
self, self,
num_blocks: int, num_blocks: int,
@ -122,7 +131,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self.metric_data = CacheMetricData() self.metric_data = CacheMetricData()
# Implements Block.Factory.
def _create_block( def _create_block(
self, self,
prev_block: Optional[Block], prev_block: Optional[Block],
@ -737,6 +745,14 @@ class PrefixCachingBlock(Block):
such as adapters that influence the block, apart from the token_ids. such as adapters that influence the block, apart from the token_ids.
""" """
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
def __init__( def __init__(
self, self,
prev_block: Optional[Block], prev_block: Optional[Block],
@ -891,13 +907,13 @@ class PrefixCachingBlock(Block):
is_first_block = self._prev_block is None is_first_block = self._prev_block is None
prev_block_hash = ( prev_block_hash = (
None if is_first_block else self._none_hash if is_first_block else
self._prev_block.content_hash # type: ignore self._prev_block.content_hash # type: ignore
) )
# Previous block exists but does not yet have a hash. # Previous block exists but does not yet have a hash.
# Return no hash in this case. # Return no hash in this case.
if prev_block_hash is None and not is_first_block: if prev_block_hash == self._none_hash and not is_first_block:
return None return None
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens( self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
@ -907,8 +923,9 @@ class PrefixCachingBlock(Block):
extra_hash=self._extra_hash) extra_hash=self._extra_hash)
return self._cached_content_hash return self._cached_content_hash
@staticmethod @classmethod
def hash_block_tokens(is_first_block: bool, def hash_block_tokens(cls,
is_first_block: bool,
prev_block_hash: Optional[int], prev_block_hash: Optional[int],
cur_block_token_ids: List[int], cur_block_token_ids: List[int],
extra_hash: Optional[int] = None) -> int: extra_hash: Optional[int] = None) -> int:
@ -929,7 +946,8 @@ class PrefixCachingBlock(Block):
Returns: Returns:
- int: The computed hash value for the block. - int: The computed hash value for the block.
""" """
assert (prev_block_hash is None) == is_first_block if is_first_block and prev_block_hash is None:
prev_block_hash = cls._none_hash
return hash((is_first_block, prev_block_hash, *cur_block_token_ids, return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
extra_hash)) extra_hash))
@ -949,6 +967,14 @@ class ComputedBlocksTracker:
cached block hashes in the allocator. cached block hashes in the allocator.
""" """
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
def __init__( def __init__(
self, self,
allocator: DeviceAwareBlockAllocator, allocator: DeviceAwareBlockAllocator,
@ -994,7 +1020,7 @@ class ComputedBlocksTracker:
# We need to know the hash of the previous block to compute the hash of # We need to know the hash of the previous block to compute the hash of
# the current block so that blocks could be uniquely identified across # the current block so that blocks could be uniquely identified across
# sequences of prefixes. # sequences of prefixes.
prev_block_hash = (None if cur_num_blocks_recorded == 0 else prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else
block_hashes_recorded[-1]) block_hashes_recorded[-1])
# Only update the computed block hashes for the new blocks # Only update the computed block hashes for the new blocks
for i in range(cur_num_blocks_recorded, num_computed_blocks): for i in range(cur_num_blocks_recorded, num_computed_blocks):
@ -1009,7 +1035,7 @@ class ComputedBlocksTracker:
# This has to be kept in sync with the allocator's hash # This has to be kept in sync with the allocator's hash
# calculation. # calculation.
block_hash = PrefixCachingBlock.hash_block_tokens( block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None, is_first_block=prev_block_hash == self._none_hash,
prev_block_hash=prev_block_hash, prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids, cur_block_token_ids=block_token_ids,
extra_hash=extra_hash, extra_hash=extra_hash,

View File

@ -263,6 +263,15 @@ def hash_block_tokens(
The hash value of the block and the token ids in the block. The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block. The entire tuple is used as the hash key of the block.
""" """
if not parent_block_hash:
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
parent_block_hash = hash('None')
curr_block_token_ids_tuple = tuple(curr_block_token_ids) curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType( return BlockHashType(
hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)), hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)),