[Core] Improve hash collision avoidance in prefix caching (#12621)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-02-03 19:28:20 -05:00 committed by GitHub
parent 5095e96606
commit 73b35cca7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 45 additions and 10 deletions

View File

@ -65,8 +65,8 @@ class TestPrefixCachingBlock:
previous_block = MagicMock(spec=PrefixCachingBlock)
prev_block_hash = random.randint(0, 1000)
previous_block.content_hash = (prev_block_hash
if prev_block_has_hash else None)
previous_block.content_hash = (prev_block_hash if prev_block_has_hash
else hash('None'))
num_to_fill = block_size if is_curr_block_full else random.randint(
0, block_size - 1)

View File

@ -65,6 +65,15 @@ class PrefixCachingBlockAllocator(BlockAllocator):
from 0 to num_blocks - 1.
"""
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
# Implements Block.Factory.
def __init__(
self,
num_blocks: int,
@ -122,7 +131,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self.metric_data = CacheMetricData()
# Implements Block.Factory.
def _create_block(
self,
prev_block: Optional[Block],
@ -737,6 +745,14 @@ class PrefixCachingBlock(Block):
such as adapters that influence the block, apart from the token_ids.
"""
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
def __init__(
self,
prev_block: Optional[Block],
@ -891,13 +907,13 @@ class PrefixCachingBlock(Block):
is_first_block = self._prev_block is None
prev_block_hash = (
None if is_first_block else
self._none_hash if is_first_block else
self._prev_block.content_hash # type: ignore
)
# Previous block exists but does not yet have a hash.
# Return no hash in this case.
if prev_block_hash is None and not is_first_block:
if prev_block_hash == self._none_hash and not is_first_block:
return None
self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
@ -907,8 +923,9 @@ class PrefixCachingBlock(Block):
extra_hash=self._extra_hash)
return self._cached_content_hash
@staticmethod
def hash_block_tokens(is_first_block: bool,
@classmethod
def hash_block_tokens(cls,
is_first_block: bool,
prev_block_hash: Optional[int],
cur_block_token_ids: List[int],
extra_hash: Optional[int] = None) -> int:
@ -929,7 +946,8 @@ class PrefixCachingBlock(Block):
Returns:
- int: The computed hash value for the block.
"""
assert (prev_block_hash is None) == is_first_block
if is_first_block and prev_block_hash is None:
prev_block_hash = cls._none_hash
return hash((is_first_block, prev_block_hash, *cur_block_token_ids,
extra_hash))
@ -949,6 +967,14 @@ class ComputedBlocksTracker:
cached block hashes in the allocator.
"""
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
_none_hash: int = hash('None')
def __init__(
self,
allocator: DeviceAwareBlockAllocator,
@ -994,7 +1020,7 @@ class ComputedBlocksTracker:
# We need to know the hash of the previous block to compute the hash of
# the current block so that blocks could be uniquely identified across
# sequences of prefixes.
prev_block_hash = (None if cur_num_blocks_recorded == 0 else
prev_block_hash = (self._none_hash if cur_num_blocks_recorded == 0 else
block_hashes_recorded[-1])
# Only update the computed block hashes for the new blocks
for i in range(cur_num_blocks_recorded, num_computed_blocks):
@ -1009,7 +1035,7 @@ class ComputedBlocksTracker:
# This has to be kept in sync with the allocator's hash
# calculation.
block_hash = PrefixCachingBlock.hash_block_tokens(
is_first_block=prev_block_hash is None,
is_first_block=prev_block_hash == self._none_hash,
prev_block_hash=prev_block_hash,
cur_block_token_ids=block_token_ids,
extra_hash=extra_hash,

View File

@ -263,6 +263,15 @@ def hash_block_tokens(
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
if not parent_block_hash:
# Note that we use 'None' as a string here instead of None because
# as of Python 3.12, hash(None) returns a constant predictable value.
# This could possibly make it easier to find and exploit hash
# collisions. 'None' as a string will be hashed differently per process,
# but consistently within the same process. This is the same as the
# behavior of None prior to Python 3.12.
parent_block_hash = hash('None')
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHashType(
hash((parent_block_hash, curr_block_token_ids_tuple, extra_keys)),