Add Automatic Prefix Caching (#2762)

Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
Sage Moore 2024-03-02 03:50:01 -05:00 committed by GitHub
parent baee28c46c
commit ce4f5a29fb
18 changed files with 615 additions and 289 deletions

View File

@ -73,10 +73,10 @@ def run_vllm(
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
enable_prefix_caching: bool,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
@ -87,7 +87,7 @@ def run_vllm(
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
)
enable_prefix_caching=enable_prefix_caching)
# Add the requests to the engine.
for prompt, _, output_len in requests:
@ -211,7 +211,8 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device)
args.kv_cache_dtype, args.device,
args.enable_prefix_caching)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -302,6 +303,7 @@ if __name__ == "__main__":
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument("--enable_prefix_caching", action='store_true')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

View File

@ -81,6 +81,10 @@ Below, you can find an explanation of every engine argument for vLLM:
Token block size for contiguous chunks of tokens.
.. option:: --enable-prefix-caching
Enables automatic prefix caching.
.. option:: --seed <seed>
Random seed for operations.
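For reference, enabling the flag from Python looks like the following minimal sketch. It assumes the `LLM` entrypoint forwards extra keyword arguments to `EngineArgs` (which gains the `enable_prefix_caching` field in this commit); the model name is only a placeholder.

from vllm import LLM, SamplingParams

# enable_prefix_caching maps to the --enable-prefix-caching engine argument.
llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], sampling_params)
print(outputs[0].outputs[0].text)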

View File

@ -37,20 +37,13 @@ for output in outputs:
print("-" * 80)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
# The llm.generate call will batch all prompts and send the batch at once if resources allow.
# The prefix will only be cached after the first batch is processed, so we need to call generate once
# to calculate the prefix and cache it.
outputs = llm.generate(generating_prompts[0],
sampling_params,
prefix_pos=[prefix_pos])
outputs = llm.generate(generating_prompts[0], sampling_params)
# Subsequent batches can leverage the cached prefix
outputs = llm.generate(generating_prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(generating_prompts))
outputs = llm.generate(generating_prompts, sampling_params)
# Print the outputs. You should see the same outputs as before
for output in outputs:
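For context, here is a self-contained sketch of the simplified example flow above (the prefix text, prompts, and model name are placeholders, and the flag is set explicitly here; the warm-up/reuse pattern is the one shown in the diff):

from vllm import LLM, SamplingParams

prefix = "You are a helpful school principal. "  # shared prefix (placeholder)
prompts = ["Draft one interview question.", "Draft a welcome note."]
generating_prompts = [prefix + p for p in prompts]

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)
sampling_params = SamplingParams(temperature=0.0)

# The first call computes and caches the KV blocks of the shared prefix.
llm.generate(generating_prompts[0], sampling_params)
# Later batches reuse those cached blocks automatically; no prefix_pos needed.
outputs = llm.generate(generating_prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)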

View File

@ -4,38 +4,73 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest
from vllm import LLM, SamplingParams
prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on these information, fulfill "
"the following paragraph: ")
from vllm.core.block_manager import BlockAllocator
from vllm.utils import Device
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
example_prompts,
model: str,
max_tokens: int,
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
block_size: int,
num_blocks: int,
):
llm = LLM(model=model)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
prompts = [prefix + prompt for prompt in example_prompts]
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs_without_prefix = llm.generate(prompts, sampling_params)
outputs_with_prefix = llm.generate(prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(prompts))
for output_without_prefix, output_with_prefix in zip(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
block_hash = 1
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)
# Allocate two PhysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock
first_block = block_allocator.allocate(block_hash, 0)
second_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (second_block.ref_count == 2)
# Free the first_block and confirm that the ref_count is correctly decremented on the second block
block_allocator.free(first_block)
assert (second_block.ref_count == 1)
# Free the second block
block_allocator.free(second_block)
# Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still get the same block back
first_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)
@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
block_size = 16
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)
blocks = []
for i in range(num_blocks):
# use i as the block_hash
blocks.append(block_allocator.allocate(i, 0))
# Free all blocks
for block in blocks:
block_allocator.free(block)
# Allocate a new block and confirm that it's the first block freed, i.e. the least recently used block
new_block_hash = block_size
new_block = block_allocator.allocate(new_block_hash, 0)
assert (new_block == blocks[0])
assert (new_block.block_hash == new_block_hash)
# Reallocate the second block in `blocks` to remove it from the free list
realloc_block_hash = 1
realloc_block = block_allocator.allocate(realloc_block_hash, 0)
assert (realloc_block == blocks[realloc_block_hash])
assert (realloc_block.block_hash == realloc_block_hash)
# Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list
new_block_hash = block_size + 1
new_block = block_allocator.allocate(new_block_hash, 0)
assert (realloc_block != new_block)
assert (new_block.block_hash == new_block_hash)
assert (new_block.block_number == 2)
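As a complement to these tests, a minimal sketch (not part of the test file) showing that the LRU policy follows the last_accessed timestamps recorded on freed blocks; the timestamps are set by hand purely for illustration:

from vllm.core.block_manager import BlockAllocator
from vllm.utils import Device

allocator = BlockAllocator(Device.CPU, 16, 2, enable_caching=True)
a = allocator.allocate(block_hash=0)
b = allocator.allocate(block_hash=1)
# Pretend block b was touched more recently than block a.
a.last_accessed = 1.0
b.last_accessed = 2.0
allocator.free(a)
allocator.free(b)
# The allocator is full, so this allocation must evict; LRU picks block a.
c = allocator.allocate(block_hash=2)
assert c is a
assert c.block_hash == 2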

View File

@ -0,0 +1,76 @@
"""Test hashing of cache blocks.
Run `pytest tests/test_cache_block_hashing.py`.
"""
import pytest
from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Sequence
# Make two prefixes with different first blocks.
prefix_start = [("You are an expert"), ("You are a")]
prefix_common = (
" school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on this, fulfill "
"the following: ")
prefixes = [start + prefix_common for start in prefix_start]
# Sample prompts.
sample_prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
# Helper function.
def flatten_2d(li):
return [lss for ls in li for lss in ls]
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("max_num_seqs", [256])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
tokenizer = TokenizerGroup(
tokenizer_id="facebook/opt-125m",
enable_lora=False,
max_num_seqs=max_num_seqs,
max_input_length=None,
)
hashes = []
for prefix in prefixes:
hashes.append([])
prompts = [prefix + prompt for prompt in sample_prompts]
seq_id = 0
for prompt in prompts:
hashes[-1].append([])
prompt_token_ids = tokenizer.encode(prompt)
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
num_blocks = len(prompt_token_ids) // block_size
for idx in range(num_blocks):
hashes[-1][-1].append(seq.hash_of_block(idx))
seq_id += 1
# Check that hashes made with two prefixes with different first blocks are
# different everywhere.
for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
assert (hash0 != hash1)
# Check that hashes of different prompts made with the same prefix are the
# same for every block except the last one, which covers the differing prompt.
for hash_pref in hashes:
same_hashes = [tuple(h[:-1]) for h in hash_pref]
different_hashes = [h[-1] for h in hash_pref]
assert (len(set(same_hashes)) == 1)
assert (len(set(different_hashes)) == len(different_hashes))
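The same property can be seen directly on Sequence.hash_of_block; a minimal sketch with made-up token ids (the prompt strings are irrelevant to the hash):

from vllm.sequence import Sequence

block_size = 4
seq_a = Sequence(0, "a", [1, 2, 3, 4, 5, 6, 7, 8], block_size)
seq_b = Sequence(1, "b", [1, 2, 3, 4, 9, 9, 9, 9], block_size)
# The first block's tokens are identical, so its hash is shared.
assert seq_a.hash_of_block(0) == seq_b.hash_of_block(0)
# The second block covers diverging tokens, so the hashes differ.
assert seq_a.hash_of_block(1) != seq_b.hash_of_block(1)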

View File

@ -5,6 +5,8 @@ from vllm.utils import Device
_BLANK_TOKEN_ID = -1
DEFAULT_LAST_ACCESSED_TIME = -1
class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.
@ -55,17 +57,27 @@ class PhysicalTokenBlock:
device: Device,
block_number: int,
block_size: int,
block_hash: int,
num_hashed_tokens: int,
) -> None:
self.device = device
self.block_number = block_number
self.block_size = block_size
self.block_hash = block_hash
self.num_hashed_tokens = num_hashed_tokens
self.ref_count = 0
self.last_accessed = DEFAULT_LAST_ACCESSED_TIME
self.computed = False
def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'ref_count={self.ref_count})')
f'num_hashed_tokens={self.num_hashed_tokens}, '
f'ref_count={self.ref_count}, '
f'last_accessed={self.last_accessed}, '
f'computed={self.computed})')
# Mapping: logical block number -> physical block.

View File

@ -303,12 +303,14 @@ class CacheConfig:
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
self._verify_cache_dtype()

View File

@ -1,10 +1,13 @@
"""A block manager that manages token blocks."""
import enum
from itertools import count
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple
from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor
class BlockAllocator:
@ -15,29 +18,68 @@ class BlockAllocator:
the reference count becomes zero, the block is added back to the free list.
"""
def __init__(
self,
def __init__(self,
device: Device,
block_size: int,
num_blocks: int,
) -> None:
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
enable_caching: bool = False) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
self.enable_caching = enable_caching
# Initialize the free blocks.
self.free_blocks: BlockTable = []
for i in range(num_blocks):
block = PhysicalTokenBlock(device=device,
block_number=i,
block_size=block_size)
self.free_blocks.append(block)
self.current_num_blocks = 0
self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}
def allocate(self) -> PhysicalTokenBlock:
if not self.free_blocks:
raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop()
block.ref_count = 1
# Switch over to FIFO eviction when caching is disabled
if not self.enable_caching:
eviction_policy = EvictionPolicy.FIFO
self.evictor: Evictor = make_evictor(eviction_policy)
self.default_hash_ctr = count()
def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
block = self.evictor.evict()
block.block_hash = block_hash
block.num_hashed_tokens = num_hashed_tokens
return block
block = PhysicalTokenBlock(device=self.device,
block_number=self.current_num_blocks,
block_size=self.block_size,
block_hash=block_hash,
num_hashed_tokens=num_hashed_tokens)
self.current_num_blocks += 1
return block
def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
# If caching is disabled, just allocate a new block and return it
if not self.enable_caching:
block = self.allocate_block(next(self.default_hash_ctr),
num_hashed_tokens)
block.ref_count += 1
return block
if block_hash is None:
block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
if block_hash not in self.cached_blocks:
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
assert block.block_hash == block_hash
block.ref_count += 1
return block
def free(self, block: PhysicalTokenBlock) -> None:
@ -45,10 +87,27 @@ class BlockAllocator:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
self.free_blocks.append(block)
assert block.block_hash not in self.evictor
self.evictor.add(block)
# If caching is enabled, remove the block from the cached_blocks
if self.enable_caching:
del self.cached_blocks[block.block_hash]
def get_num_free_blocks(self) -> int:
return len(self.free_blocks)
return self.num_blocks - self.current_num_blocks + self.evictor.num_blocks
def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor
def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
# If caching is enabled, update the block's hash and the cached_blocks dictionary.
if self.enable_caching:
assert not self.contains_block(block_hash)
old_hash = block.block_hash
block.block_hash = block_hash
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block
class AllocStatus(enum.Enum):
@ -75,6 +134,7 @@ class BlockSpaceManager:
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
@ -89,11 +149,17 @@ class BlockSpaceManager:
self.watermark = watermark
assert watermark >= 0.0
self.enable_caching = enable_caching
self.watermark_blocks = int(watermark * num_gpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU,
block_size,
num_gpu_blocks,
enable_caching=enable_caching)
self.cpu_allocator = BlockAllocator(Device.CPU,
block_size,
num_cpu_blocks,
enable_caching=enable_caching)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}
@ -103,9 +169,6 @@ class BlockSpaceManager:
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if seq_group.prefix is not None and seq_group.prefix.allocated:
num_required_blocks -= seq_group.prefix.get_num_blocks()
if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
@ -129,36 +192,16 @@ class BlockSpaceManager:
num_prompt_blocks = len(seq.logical_token_blocks)
block_table: BlockTable = []
prefix_block_table: BlockTable = []
num_prefix_blocks = 0
prefix = seq_group.prefix
if prefix is not None and prefix.allocated:
# Prefix has already been allocated. Use the existing block table.
num_prompt_blocks -= prefix.get_num_blocks()
for block in prefix.block_table:
block.ref_count += seq_group.num_seqs()
block_table.append(block)
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
block = self.gpu_allocator.allocate(
seq.hash_of_block(logical_idx),
seq.num_hashed_tokens_of_block(logical_idx))
block_table.append(block)
if prefix is not None and not prefix.allocated:
# Allocate blocks for the prefix, we will compute the prefix's
# KV cache in this run.
num_prefix_blocks = prefix.get_num_blocks()
prefix_block_table = block_table[:num_prefix_blocks]
for block in prefix_block_table:
block.ref_count += 1
prefix.set_block_table(prefix_block_table)
# Assign the block table for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
@ -170,12 +213,72 @@ class BlockSpaceManager:
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks
def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
def _promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
# Compute a new hash for the block so that it can be shared by other Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
# If new_hash is already in the cached table, free last_block and return the cached version
if self.gpu_allocator.contains_block(new_hash):
self.gpu_allocator.free(last_block)
return self.gpu_allocator.allocate(new_hash)
else:
self.gpu_allocator.update_hash(new_hash, last_block)
return last_block
def _is_last_block_full(
self,
seq: Sequence,
) -> bool:
token_ids_len = len(seq.data.get_token_ids())
return token_ids_len > 0 and token_ids_len % seq.block_size == 0
def _is_last_block(
self,
seq: Sequence,
index: int,
) -> bool:
return index == len(seq.logical_token_blocks) - 1
def _maybe_promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
if self._is_last_block_full(seq):
return self._promote_last_block(seq, last_block)
else:
return last_block
def _allocate_last_physical_block(
self,
seq: Sequence,
) -> PhysicalTokenBlock:
block_hash: Optional[int] = None
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
if block_hash is None:
assert new_block.ref_count == 1
return new_block
def append_slot(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id]
# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks):
# Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1
if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
# reuse a block
@ -184,8 +287,8 @@ class BlockSpaceManager:
else:
# The sequence has a new logical block.
# Allocate a new physical block.
block = self.gpu_allocator.allocate()
block_table.append(block)
new_block = self._allocate_last_physical_block(seq)
block_table.append(new_block)
return None
# We want to append the token to the last physical block.
@ -193,11 +296,15 @@ class BlockSpaceManager:
assert last_block.device == Device.GPU
if last_block.ref_count == 1:
# Not shared with other sequences. Appendable.
# If the last block is now complete, promote it to a full block so that it can be shared
new_block = self._maybe_promote_last_block(seq, last_block)
block_table[-1] = new_block
return None
else:
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
new_block = self.gpu_allocator.allocate()
new_block = self._allocate_last_physical_block(seq)
block_table[-1] = new_block
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number
@ -233,25 +340,18 @@ class BlockSpaceManager:
def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
if seq_group.prefix is not None:
# make sure to swap in the prefix first
assert seq_group.prefix.allocated and seq_group.prefix.computed
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
if seq_group.prefix is not None:
for block in seq_group.prefix.block_table:
new_block_table.append(block)
block.ref_count += 1
for cpu_block in block_table:
if cpu_block in mapping:
gpu_block = mapping[cpu_block]
gpu_block.ref_count += 1
else:
gpu_block = self.gpu_allocator.allocate()
gpu_block = self.gpu_allocator.allocate(
cpu_block.block_hash, cpu_block.num_hashed_tokens)
mapping[cpu_block] = gpu_block
new_block_table.append(gpu_block)
# Free the CPU block swapped in to GPU.
@ -276,17 +376,12 @@ class BlockSpaceManager:
block_table = self.block_tables[seq.seq_id]
for gpu_block in block_table:
if (seq_group.prefix is not None
and gpu_block in seq_group.prefix.block_table):
# NOTE: We do not swap out the prefix blocks for now.
self.gpu_allocator.free(gpu_block)
continue
if gpu_block in mapping:
cpu_block = mapping[gpu_block]
cpu_block.ref_count += 1
else:
cpu_block = self.cpu_allocator.allocate()
cpu_block = self.cpu_allocator.allocate(
gpu_block.block_hash, gpu_block.num_hashed_tokens)
mapping[gpu_block] = cpu_block
new_block_table.append(cpu_block)
# Free the GPU block swapped out to CPU.
@ -328,3 +423,49 @@ class BlockSpaceManager:
def get_num_free_cpu_blocks(self) -> int:
return self.cpu_allocator.get_num_free_blocks()
def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
block_table = self.block_tables[seq.seq_id]
for block in block_table:
block.last_accessed = access_time
def compute_last_full_block_in_seq(self, seq: Sequence):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // seq.block_size - 1
block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
return
block_table[max_full_block].computed = True
def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
if seq.seq_id not in self.block_tables:
return []
block_table = self.block_tables[seq.seq_id]
for block_idx in reversed(range(len(block_table))):
if block_table[block_idx].computed:
return [b.block_number for b in block_table[:block_idx + 1]]
return []
# Can return non-empty result only with prefix caching enabled.
def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
if not self.enable_caching:
return []
ids_list = [
self.get_all_block_ids_till_computed(seq)
for seq in iter(seq_group.seqs_dict.values())
]
return commonprefix([ids for ids in ids_list if ids != []])
# We only mark the last full block because with prefix caching,
# all blocks until the marked one are guaranteed to be computed.
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.seqs_dict.values():
self.compute_last_full_block_in_seq(seq)
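get_common_computed_block_ids leans on os.path.commonprefix, which compares arbitrary sequences element by element, not just path strings. A small illustration with made-up block ids:

from os.path import commonprefix

ids_list = [[0, 1, 2, 7], [0, 1, 2, 9], [0, 1, 5]]
# The shared leading block ids across all sequences in the group.
assert commonprefix(ids_list) == [0, 1]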

vllm/core/evictor.py (new file, 161 lines)
View File

@ -0,0 +1,161 @@
import enum
from typing import Dict, List, Optional
from abc import ABC, abstractmethod, abstractproperty
from vllm.block import PhysicalTokenBlock
class EvictionPolicy(enum.Enum):
"""Enum for eviction policy used by make_evictor to instantiate the correct
Evictor subclass.
"""
LRU = enum.auto()
FIFO = enum.auto()
class Evictor(ABC):
"""The Evictor subclasses should be used by the BlockAllocator class to
handle eviction of freed PhysicalTokenBlocks.
"""
@abstractmethod
def __init__(self):
pass
@abstractmethod
def __contains__(self, block_hash: int) -> bool:
pass
@abstractmethod
def evict(self) -> PhysicalTokenBlock:
"""Runs the eviction algorithm and returns the evicted block"""
pass
@abstractmethod
def add(self, block: PhysicalTokenBlock):
"""Adds block to the evictor, making it a candidate for eviction"""
pass
@abstractmethod
def remove(self, block_hash: int) -> PhysicalTokenBlock:
"""Simply removes the block with the hash value block_hash from the
evictor. Caller is responsible for making sure that block_hash is contained
in the evictor before calling remove. Should be used to "bring back" blocks
that have been freed but not evicted yet.
"""
pass
@abstractproperty
def num_blocks(self) -> int:
pass
class LRUEvictor(Evictor):
"""Evicts in a least-recently-used order using the last_accessed timestamp
that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
the same last_accessed time, then the one with the largest num_hashed_tokens
will be evicted. If two blocks each have the lowest last_accessed time and
highest num_hashed_tokens value, then one will be chose arbitrarily
"""
def __init__(self):
self.free_table: Dict[int, PhysicalTokenBlock] = {}
def __contains__(self, block_hash: int) -> bool:
return block_hash in self.free_table
# TODO: The performance of this evict function can be optimized further.
def evict(self) -> PhysicalTokenBlock:
free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values())
if len(free_blocks) == 0:
raise ValueError("No usable cache memory left")
# Find lowest timestamp
lowest_timestamp = free_blocks[0].last_accessed
for block in free_blocks:
if block.last_accessed < lowest_timestamp:
lowest_timestamp = block.last_accessed
# Find all blocks with the lowest timestamp
least_recent: List[PhysicalTokenBlock] = []
for block in free_blocks:
if block.last_accessed == lowest_timestamp:
least_recent.append(block)
# Find the highest num_hashed_tokens among those blocks
highest_num_hashed_tokens = 0
for block in least_recent:
if block.num_hashed_tokens > highest_num_hashed_tokens:
highest_num_hashed_tokens = block.num_hashed_tokens
evicted_block: Optional[PhysicalTokenBlock] = None
# Among the least recently used blocks, pick the first one with the highest num_hashed_tokens
for block in least_recent:
if block.num_hashed_tokens == highest_num_hashed_tokens:
evicted_block = block
break
assert evicted_block is not None
del self.free_table[evicted_block.block_hash]
evicted_block.computed = False
return evicted_block
def add(self, block: PhysicalTokenBlock):
self.free_table[block.block_hash] = block
def remove(self, block_hash: int) -> PhysicalTokenBlock:
if block_hash not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
block: PhysicalTokenBlock = self.free_table[block_hash]
del self.free_table[block_hash]
return block
@property
def num_blocks(self) -> int:
return len(self.free_table)
class RandomEvictor(Evictor):
"""Evicts in a first-in-first-out order"""
def __init__(self):
self.free_table: Dict[int, PhysicalTokenBlock] = {}
def __contains__(self, block_hash: int) -> bool:
return block_hash in self.free_table
def evict(self) -> PhysicalTokenBlock:
if len(self.free_table) == 0:
raise ValueError("No usable cache memory left")
evicted_block = next(iter(self.free_table.values()))
evicted_block.computed = False
del self.free_table[evicted_block.block_hash]
return evicted_block
def add(self, block: PhysicalTokenBlock):
self.free_table[block.block_hash] = block
def remove(self, block_hash: int) -> PhysicalTokenBlock:
if block_hash not in self.free_table:
raise ValueError(
"Attempting to remove block that's not in the evictor")
block: PhysicalTokenBlock = self.free_table[block_hash]
del self.free_table[block_hash]
return block
@property
def num_blocks(self) -> int:
return len(self.free_table)
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
if eviction_policy == EvictionPolicy.LRU:
return LRUEvictor()
elif eviction_policy == EvictionPolicy.FIFO:
return RandomEvictor()
else:
raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")

View File

@ -10,7 +10,6 @@ from vllm.lora.request import LoRARequest
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.prefix import PrefixPool
logger = init_logger(__name__)
@ -95,10 +94,8 @@ class Scheduler:
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window)
# Create the prefix pool to cache the prefixes.
self.prefix_pool = PrefixPool(self.cache_config.block_size)
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)
# Sequence groups in the WAITING state.
self.waiting: Deque[SequenceGroup] = deque()
@ -374,10 +371,12 @@ class Scheduler:
seq_data: Dict[int, SequenceData] = {}
block_tables: Dict[int, List[int]] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
@ -386,7 +385,8 @@ class Scheduler:
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
computed_block_nums=self.block_manager.
get_common_computed_block_ids(seq_group),
state=seq_group.state,
)
seq_group_metadata_list.append(seq_group_metadata)
@ -496,3 +496,6 @@ class Scheduler:
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
self.block_manager.mark_blocks_as_computed(seq_group)

View File

@ -25,6 +25,7 @@ class EngineArgs:
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
swap_space: int = 4 # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
@ -173,6 +174,11 @@ class EngineArgs:
default=EngineArgs.block_size,
choices=[8, 16, 32, 128],
help='token block size')
parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching')
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
@ -293,7 +299,8 @@ class EngineArgs:
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
model_config.get_sliding_window())
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,

View File

@ -225,7 +225,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@ -245,7 +244,6 @@ class _AsyncLLMEngine(LLMEngine):
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async def _run_workers_async(
@ -422,7 +420,6 @@ class AsyncLLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@ -435,7 +432,6 @@ class AsyncLLMEngine:
max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling_params: {sampling_params}, "
f"prompt_token_ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
@ -472,8 +468,7 @@ class AsyncLLMEngine:
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
lora_request=lora_request)
return stream
@ -484,7 +479,6 @@ class AsyncLLMEngine:
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@ -500,11 +494,6 @@ class AsyncLLMEngine:
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Yields:
The output `RequestOutput` objects from the LLMEngine for the
@ -565,7 +554,6 @@ class AsyncLLMEngine:
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
async for request_output in stream:

View File

@ -415,7 +415,6 @@ class LLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
"""Add a request to the engine's request pool.
@ -432,11 +431,6 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
Details:
- Set arrival_time to the current time if it is None.
@ -479,18 +473,13 @@ class LLMEngine:
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request)
# Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request, prefix)
arrival_time, lora_request)
# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
@ -752,6 +741,13 @@ class LLMEngine:
now = time.time()
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
# If prefix caching is enabled, mark all blocks in the sequence groups
# as computed so that future requests don't attempt to recompute them.
if self.cache_config.enable_prefix_caching:
for seq_group in scheduled_seq_groups:
self.scheduler.mark_blocks_as_computed(seq_group)
for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs)
@ -768,12 +764,6 @@ class LLMEngine:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
# Update prefix state, now all the uncomputed prefixes are computed.
for seq_group in scheduled_seq_groups:
if (seq_group.prefix is not None and seq_group.prefix.allocated
and not seq_group.prefix.computed):
seq_group.prefix.computed = True
# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))

View File

@ -39,15 +39,11 @@ async def generate(request: Request) -> Response:
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = engine.generate(prompt,
sampling_params,
request_id,
prefix_pos=prefix_pos)
results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:

View File

@ -124,7 +124,6 @@ class LLM:
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
@ -140,11 +139,6 @@ class LLM:
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
@ -171,14 +165,12 @@ class LLM:
prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos_i)
lora_request=lora_request)
return self._run_engine(use_tqdm)
def _add_request(
@ -187,15 +179,13 @@ class LLM:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos)
lora_request=lora_request)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.

View File

@ -1,87 +0,0 @@
from typing import Dict, List, Sequence, Tuple, Optional
from vllm.block import BlockTable
class Prefix:
"""Data and states associated with a prefix of prompt tokens for multiple
sequence groups.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
token_ids: The token ids of the prefix.
block_size: The block size of the executed model.
"""
def __init__(
self,
token_ids: Sequence[int],
block_size: int,
) -> None:
self.token_ids = tuple(token_ids)
self.block_size = block_size
self.length = len(token_ids)
self.hash = hash(token_ids)
assert self.length % block_size == 0
self.block_table: Optional[BlockTable] = None
self.computed = False
@property
def allocated(self) -> bool:
return self.block_table is not None
def get_num_blocks(self) -> int:
return self.length // self.block_size
def get_block_numbers(self) -> List[int]:
return [block.block_number for block in self.block_table]
def get_length(self) -> int:
return self.length
def __hash__(self) -> int:
return self.hash
def set_block_table(self, block_table: BlockTable) -> None:
self.block_table = block_table.copy()
class PrefixPool:
"""Manages all the prompt prefixes.
NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.
Args:
block_size: The block size of the executed model.
Attributes:
prefixes: A list of all the prefixes.
block_size: The block size of the executed model.
"""
def __init__(
self,
block_size: int,
) -> None:
# TODO(zhuohan): Add a capacity limit to the prefix pool.
self.prefixes: Dict[int, Prefix] = {}
self.block_size = block_size
def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
new_length = len(token_ids) // self.block_size * self.block_size
return tuple(token_ids[:new_length])
def add_or_get_prefix(self, token_ids: Sequence[int],
lora_int_id: int) -> Optional[Prefix]:
token_ids = self._truncate_token_ids(token_ids)
if len(token_ids) == 0:
# Prefix is empty.
return None
prefix = Prefix(token_ids, self.block_size)
prefix_hash = hash((prefix, lora_int_id))
if prefix_hash not in self.prefixes:
self.prefixes[prefix_hash] = prefix
return self.prefixes[prefix_hash]

View File

@ -5,7 +5,6 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Union
from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest
@ -161,6 +160,16 @@ class Sequence:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
# TODO The current hashing function is O(L^2). We should optimize this in
# the future.
def hash_of_block(self, logical_idx: int) -> int:
# Compute the number of tokens covered up to and including this block
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size
def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks),
@ -265,7 +274,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""
def __init__(
@ -275,7 +283,6 @@ class SequenceGroup:
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@ -286,7 +293,6 @@ class SequenceGroup:
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()
@ -302,6 +308,10 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids
@property
def block_size(self) -> int:
return next(iter(self.seqs_dict.values())).block_size
@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@ -408,7 +418,6 @@ class SequenceGroupMetadata:
numbers)
state: Internal state tied to this sequence group.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""
def __init__(
@ -419,7 +428,7 @@ class SequenceGroupMetadata:
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
) -> None:
self.request_id = request_id
@ -428,7 +437,7 @@ class SequenceGroupMetadata:
self.sampling_params = sampling_params
self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix
self.computed_block_nums = computed_block_nums
self.state = SequenceGroupState() if state is None else state
@property

View File

@ -145,33 +145,37 @@ class ModelRunner:
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
prefix_len = 0
prefix = seq_group_metadata.prefix
if prefix is not None and prefix.computed:
prefix_len = prefix.get_length()
prompt_tokens = prompt_tokens[prefix_len:]
prefix_block_tables.append(prefix.get_block_numbers())
computed_len = 0
# NOTE: This only works for oooooooxxx style attention.
computed_block_nums = seq_group_metadata.computed_block_nums
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
else:
prefix_block_tables.append([])
# actual prompt lens
context_lens.append(prefix_len)
subquery_lens.append(prompt_len - prefix_len)
context_lens.append(computed_len)
subquery_lens.append(prompt_len - computed_len)
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(
list(range(prefix_len, prefix_len + len(prompt_tokens))))
list(range(computed_len, computed_len + len(prompt_tokens))))
lora_id = seq_group_metadata.lora_int_id
if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)
lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - prefix_len
(prompt_len - computed_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))
if seq_group_metadata.block_tables is None:
@ -190,11 +194,11 @@ class ModelRunner:
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert prefix_len == 0, (
assert computed_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prefix_len, prompt_len):
for i in range(computed_len, prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
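To make the arithmetic above concrete, a small sketch with made-up numbers: if the first two 16-token blocks of a 40-token prompt are already computed, only the last 8 tokens go through the prompt phase, and their positions start after the cached region.

block_size = 16
computed_block_nums = [0, 1]          # two cached blocks
prompt_tokens = list(range(40))       # 40-token prompt (placeholder ids)
computed_len = len(computed_block_nums) * block_size   # 32
remaining = prompt_tokens[computed_len:]
assert len(remaining) == 8
input_positions = list(range(computed_len, computed_len + len(remaining)))
assert input_positions[0] == 32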