Add Automatic Prefix Caching (#2762)

Co-authored-by: ElizaWszola <eliza@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>

parent baee28c46c
commit ce4f5a29fb
@@ -73,21 +73,21 @@ def run_vllm(
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
enable_prefix_caching: bool,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
)
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching)

# Add the requests to the engine.
for prompt, _, output_len in requests:

@@ -211,7 +211,8 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype, args.device)
args.kv_cache_dtype, args.device,
args.enable_prefix_caching)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,

@@ -302,6 +303,7 @@ if __name__ == "__main__":
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument("--enable_prefix_caching", action='store_true')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
@@ -81,6 +81,10 @@ Below, you can find an explanation of every engine argument for vLLM:

    Token block size for contiguous chunks of tokens.

.. option:: --enable-prefix-caching

    Enables automatic prefix caching

.. option:: --seed <seed>

    Random seed for operations.
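For reference, a minimal sketch of turning the new flag on from the Python API; it mirrors the `enable_prefix_caching` keyword this commit threads through `EngineArgs` and the benchmark's `LLM(...)` call. The model name and prompts below are placeholders, not part of the change:

```python
# Sketch only: assumes a vLLM build that includes this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

prefix = "You are an expert school principal. "
prompts = [prefix + "Draft interview questions.",
           prefix + "Summarize the candidate profile."]

# Blocks holding the shared prefix are hashed once and their KV cache is
# reused across prompts instead of being recomputed.
outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)
```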
@@ -37,20 +37,13 @@ for output in outputs:
print("-" * 80)

# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# The llm.generate call will batch all prompts and send the batch at once if resources allow.
# The prefix will only be cached after the first batch is processed, so we need to call generate once
# to calculate the prefix and cache it.
outputs = llm.generate(generating_prompts[0],
sampling_params,
prefix_pos=[prefix_pos])
outputs = llm.generate(generating_prompts[0], sampling_params)

# Subsequent batches can leverage the cached prefix
outputs = llm.generate(generating_prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(generating_prompts))
outputs = llm.generate(generating_prompts, sampling_params)

# Print the outputs. You should see the same outputs as before
for output in outputs:
@@ -4,38 +4,73 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import pytest

from vllm import LLM, SamplingParams

prefix = (
"You are an expert school principal, skilled in effectively managing "
"faculty and staff. Draft 10-15 questions for a potential first grade "
"Head Teacher for my K-12, all-girls', independent school that emphasizes "
"community, joyful discovery, and life-long learning. The candidate is "
"coming in for a first-round panel interview for a 8th grade Math "
"teaching role. They have 5 years of previous teaching experience "
"as an assistant teacher at a co-ed, public school with experience "
"in middle school math teaching. Based on this information, fulfill "
"the following paragraph: ")
from vllm.core.block_manager import BlockAllocator
from vllm.utils import Device


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_tokens", [16])
def test_prefix_caching(
example_prompts,
model: str,
max_tokens: int,
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
block_size: int,
num_blocks: int,
):
llm = LLM(model=model)
# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1
prompts = [prefix + prompt for prompt in example_prompts]
sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs_without_prefix = llm.generate(prompts, sampling_params)
outputs_with_prefix = llm.generate(prompts,
sampling_params,
prefix_pos=[prefix_pos] * len(prompts))
for output_without_prefix, output_with_prefix in zip(
outputs_without_prefix, outputs_with_prefix):
assert (output_without_prefix.outputs[0].token_ids ==
output_with_prefix.outputs[0].token_ids)
assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1
block_hash = 1
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)

# Allocate two PhysicalTokenBlocks with the same hash and check that they are the same PhysicalTokenBlock
first_block = block_allocator.allocate(block_hash, 0)
second_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (second_block.ref_count == 2)

# Free the first_block and confirm that the ref_count is correctly decremented on the second block
block_allocator.free(first_block)
assert (second_block.ref_count == 1)

# Free the second block
block_allocator.free(second_block)

# Reallocate the first block and confirm that, even after the block had its ref_count go to 0, we still get the same block back
first_block = block_allocator.allocate(block_hash, 0)
assert (first_block == second_block)
assert (first_block.block_hash == block_hash)


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int, ):
block_size = 16
block_allocator = BlockAllocator(Device.CPU,
block_size,
num_blocks,
enable_caching=True)
blocks = []

for i in range(num_blocks):
# use i as the block_hash
blocks.append(block_allocator.allocate(i, 0))

# Free all blocks
for block in blocks:
block_allocator.free(block)

# Allocate a new block and confirm that it's the first block freed, i.e. the least recently used block
new_block_hash = block_size
new_block = block_allocator.allocate(new_block_hash, 0)
assert (new_block == blocks[0])
assert (new_block.block_hash == new_block_hash)

# Reallocate the second block in blocks to remove it from the free list
realloc_block_hash = 1
realloc_block = block_allocator.allocate(realloc_block_hash, 0)
assert (realloc_block == blocks[realloc_block_hash])
assert (realloc_block.block_hash == realloc_block_hash)

# Allocate a new block and confirm that it's not the realloc_block, since the realloc_block shouldn't be in the free list
new_block_hash = block_size + 1
new_block = block_allocator.allocate(new_block_hash, 0)
assert (realloc_block != new_block)
assert (new_block.block_hash == new_block_hash)
assert (new_block.block_number == 2)
tests/test_cache_block_hashing.py (new file, 76 lines)
@@ -0,0 +1,76 @@
"""Test hashing of cache blocks.

Run `pytest tests/test_cache_block_hashing.py`.
"""
import pytest

from vllm.transformers_utils.tokenizer import TokenizerGroup
from vllm.sequence import Sequence

# Make two prefixes with different first blocks.
prefix_start = [("You are an expert"), ("You are a")]
prefix_common = (
    " school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for a 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on this, fulfill "
    "the following: ")
prefixes = [start + prefix_common for start in prefix_start]

# Sample prompts.
sample_prompts = [
    "Hello, my name is", "The president of the United States is",
    "The capital of France is", "The future of AI is"
]


# Helper function.
def flatten_2d(li):
    return [lss for ls in li for lss in ls]


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("max_num_seqs", [256])
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):

    tokenizer = TokenizerGroup(
        tokenizer_id="facebook/opt-125m",
        enable_lora=False,
        max_num_seqs=max_num_seqs,
        max_input_length=None,
    )

    hashes = []

    for prefix in prefixes:
        hashes.append([])
        prompts = [prefix + prompt for prompt in sample_prompts]
        seq_id = 0
        for prompt in prompts:
            hashes[-1].append([])
            prompt_token_ids = tokenizer.encode(prompt)
            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

            num_blocks = len(prompt_token_ids) // block_size
            for idx in range(num_blocks):
                hashes[-1][-1].append(seq.hash_of_block(idx))

            seq_id += 1

    # Check that hashes made with two prefixes with different first blocks are
    # different everywhere.
    for hash0, hash1 in zip(flatten_2d(hashes[0]), flatten_2d(hashes[1])):
        assert (hash0 != hash1)

    # Check that hashes of different prompts made with the same prefix are the
    # same until the blocks that contain the prompt.
    for hash_pref in hashes:
        same_hashes = [tuple(h[:-1]) for h in hash_pref]
        different_hashes = [h[-1] for h in hash_pref]
        assert (len(set(same_hashes)) == 1)
        assert (len(set(different_hashes)) == len(different_hashes))
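The test above relies on `Sequence.hash_of_block`, which this commit adds later in the diff. A small self-contained sketch of that hashing rule (the block size and token values below are made up) shows why prompts that share a prefix also share the hashes of their leading blocks:

```python
# Standalone sketch mirroring Sequence.hash_of_block from this commit.
from typing import List

BLOCK_SIZE = 16

def hash_of_block(token_ids: List[int], logical_idx: int) -> int:
    # The hash covers *all* tokens up to and including this block, so a
    # block's identity depends on its entire prefix, not just its own tokens.
    num_tokens = (logical_idx + 1) * BLOCK_SIZE
    return hash(tuple(token_ids[:num_tokens]))

shared = list(range(40))          # hypothetical shared-prefix token ids
a = shared + [100, 101, 102]      # two prompts with the same prefix
b = shared + [200, 201, 202]

# Blocks 0 and 1 lie entirely inside the shared prefix -> identical hashes.
assert hash_of_block(a, 0) == hash_of_block(b, 0)
assert hash_of_block(a, 1) == hash_of_block(b, 1)
# Block 2 covers tokens past the shared prefix -> hashes differ.
assert hash_of_block(a, 2) != hash_of_block(b, 2)
```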
@@ -5,6 +5,8 @@ from vllm.utils import Device

_BLANK_TOKEN_ID = -1

DEFAULT_LAST_ACCESSED_TIME = -1


class LogicalTokenBlock:
"""A block that stores a contiguous chunk of tokens from left to right.

@@ -55,17 +57,27 @@ class PhysicalTokenBlock:
device: Device,
block_number: int,
block_size: int,
block_hash: int,
num_hashed_tokens: int,
) -> None:
self.device = device
self.block_number = block_number
self.block_size = block_size
self.block_hash = block_hash
self.num_hashed_tokens = num_hashed_tokens

self.ref_count = 0
self.last_accessed = DEFAULT_LAST_ACCESSED_TIME

self.computed = False

def __repr__(self) -> str:
return (f'PhysicalTokenBlock(device={self.device}, '
f'block_number={self.block_number}, '
f'ref_count={self.ref_count})')
f'num_hashed_tokens={self.num_hashed_tokens}, '
f'ref_count={self.ref_count}, '
f'last_accessed={self.last_accessed}, '
f'computed={self.computed})')


# Mapping: logical block number -> physical block.
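To make the new bookkeeping concrete, here is a small sketch constructing a block directly (normally only the BlockAllocator does this); the field values are illustrative, not taken from the change:

```python
# Sketch: the metadata a PhysicalTokenBlock now carries after this commit.
from vllm.block import PhysicalTokenBlock
from vllm.utils import Device

block = PhysicalTokenBlock(device=Device.GPU,
                           block_number=0,
                           block_size=16,
                           block_hash=hash(tuple(range(16))),
                           num_hashed_tokens=16)
assert block.ref_count == 0       # incremented by BlockAllocator.allocate()
assert block.computed is False    # flipped to True once its KV cache is filled
print(block)                      # repr now shows hash, last_accessed, computed
```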
@@ -303,12 +303,14 @@ class CacheConfig:
swap_space: int,
cache_dtype: str,
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
self.swap_space_bytes = swap_space * _GB
self.cache_dtype = cache_dtype
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self._verify_args()
self._verify_cache_dtype()
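A hedged sketch of how the new field reaches CacheConfig, mirroring the positional call that EngineArgs.create_engine_configs makes later in this diff; the concrete values here are illustrative defaults, not part of the change:

```python
# Sketch: building a CacheConfig with prefix caching turned on.
from vllm.config import CacheConfig

cache_config = CacheConfig(
    16,      # block_size
    0.90,    # gpu_memory_utilization
    4,       # swap_space (GiB)
    "auto",  # cache_dtype
    None,    # sliding_window
    True,    # enable_prefix_caching
)
assert cache_config.enable_prefix_caching
```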
@@ -1,10 +1,13 @@
"""A block manager that manages token blocks."""
import enum
from itertools import count
from os.path import commonprefix
from typing import Dict, List, Optional, Set, Tuple

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor


class BlockAllocator:
@@ -15,29 +18,68 @@ class BlockAllocator:
the reference count becomes zero, the block is added back to the free list.
"""

def __init__(
self,
device: Device,
block_size: int,
num_blocks: int,
) -> None:
def __init__(self,
device: Device,
block_size: int,
num_blocks: int,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
enable_caching: bool = False) -> None:
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
self.enable_caching = enable_caching

# Initialize the free blocks.
self.free_blocks: BlockTable = []
for i in range(num_blocks):
block = PhysicalTokenBlock(device=device,
block_number=i,
block_size=block_size)
self.free_blocks.append(block)
self.current_num_blocks = 0
self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

def allocate(self) -> PhysicalTokenBlock:
if not self.free_blocks:
raise ValueError("Out of memory! No free blocks are available.")
block = self.free_blocks.pop()
block.ref_count = 1
# Switch over to FIFO eviction when caching is disabled
if not self.enable_caching:
eviction_policy = EvictionPolicy.FIFO
self.evictor: Evictor = make_evictor(eviction_policy)

self.default_hash_ctr = count()

def allocate_block(self, block_hash: int,
num_hashed_tokens: int) -> PhysicalTokenBlock:
if self.current_num_blocks == self.num_blocks:
block = self.evictor.evict()
block.block_hash = block_hash
block.num_hashed_tokens = num_hashed_tokens
return block
block = PhysicalTokenBlock(device=self.device,
block_number=self.current_num_blocks,
block_size=self.block_size,
block_hash=block_hash,
num_hashed_tokens=num_hashed_tokens)
self.current_num_blocks += 1
return block

def allocate(self,
block_hash: Optional[int] = None,
num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
# If caching is disabled, just allocate a new block and return it
if not self.enable_caching:
block = self.allocate_block(next(self.default_hash_ctr),
num_hashed_tokens)
block.ref_count += 1
return block

if block_hash is None:
block_hash = next(self.default_hash_ctr)
if block_hash in self.evictor:
assert block_hash not in self.cached_blocks
block = self.evictor.remove(block_hash)
assert block.ref_count == 0
self.cached_blocks[block_hash] = block
block.ref_count += 1
assert block.block_hash == block_hash
return block
if block_hash not in self.cached_blocks:
self.cached_blocks[block_hash] = self.allocate_block(
block_hash, num_hashed_tokens)
block = self.cached_blocks[block_hash]
assert block.block_hash == block_hash
block.ref_count += 1
return block

def free(self, block: PhysicalTokenBlock) -> None:
@@ -45,10 +87,27 @@ class BlockAllocator:
raise ValueError(f"Double free! {block} is already freed.")
block.ref_count -= 1
if block.ref_count == 0:
self.free_blocks.append(block)
assert block.block_hash not in self.evictor
self.evictor.add(block)

# If caching is enabled, remove the block from the cached_blocks
if self.enable_caching:
del self.cached_blocks[block.block_hash]

def get_num_free_blocks(self) -> int:
return len(self.free_blocks)
return self.num_blocks - self.current_num_blocks + self.evictor.num_blocks

def contains_block(self, block_hash: int) -> bool:
return block_hash in self.cached_blocks or block_hash in self.evictor

def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
# If caching is enabled, update the hash of block and the cached_blocks dictionary.
if self.enable_caching:
assert not self.contains_block(block_hash)
old_hash = block.block_hash
block.block_hash = block_hash
del self.cached_blocks[old_hash]
self.cached_blocks[block_hash] = block


class AllocStatus(enum.Enum):
@@ -75,6 +134,7 @@ class BlockSpaceManager:
num_cpu_blocks: int,
watermark: float = 0.01,
sliding_window: Optional[int] = None,
enable_caching: bool = False,
) -> None:
self.block_size = block_size
self.num_total_gpu_blocks = num_gpu_blocks
@@ -89,11 +149,17 @@ class BlockSpaceManager:
self.watermark = watermark
assert watermark >= 0.0

self.enable_caching = enable_caching

self.watermark_blocks = int(watermark * num_gpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
self.gpu_allocator = BlockAllocator(Device.GPU,
block_size,
num_gpu_blocks,
enable_caching=enable_caching)
self.cpu_allocator = BlockAllocator(Device.CPU,
block_size,
num_cpu_blocks,
enable_caching=enable_caching)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}

@@ -103,9 +169,6 @@ class BlockSpaceManager:
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)

if seq_group.prefix is not None and seq_group.prefix.allocated:
num_required_blocks -= seq_group.prefix.get_num_blocks()

if self.block_sliding_window is not None:
num_required_blocks = min(num_required_blocks,
self.block_sliding_window)
@@ -129,36 +192,16 @@ class BlockSpaceManager:
num_prompt_blocks = len(seq.logical_token_blocks)

block_table: BlockTable = []
prefix_block_table: BlockTable = []
num_prefix_blocks = 0

prefix = seq_group.prefix
if prefix is not None and prefix.allocated:
# Prefix has already been allocated. Use the existing block table.
num_prompt_blocks -= prefix.get_num_blocks()
for block in prefix.block_table:
block.ref_count += seq_group.num_seqs()
block_table.append(block)

for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
block = block_table[logical_idx % self.block_sliding_window]
else:
block = self.gpu_allocator.allocate()
# Set the reference counts of the token blocks.
block.ref_count = seq_group.num_seqs()
block = self.gpu_allocator.allocate(
seq.hash_of_block(logical_idx),
seq.num_hashed_tokens_of_block(logical_idx))
block_table.append(block)

if prefix is not None and not prefix.allocated:
# Allocate blocks for the prefix, we will compute the prefix's
# KV cache in this run.
num_prefix_blocks = prefix.get_num_blocks()
prefix_block_table = block_table[:num_prefix_blocks]
for block in prefix_block_table:
block.ref_count += 1
prefix.set_block_table(prefix_block_table)

# Assign the block table for each sequence.
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
self.block_tables[seq.seq_id] = block_table.copy()
@@ -170,12 +213,72 @@ class BlockSpaceManager:
num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
return num_seqs <= num_free_gpu_blocks

def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
def _promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
# Compute a new hash for the block so that it can be shared by other Sequences
new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

# if new_hash is already in the cached table, then free last_block and return the cached version
if self.gpu_allocator.contains_block(new_hash):
self.gpu_allocator.free(last_block)
return self.gpu_allocator.allocate(new_hash)
else:
self.gpu_allocator.update_hash(new_hash, last_block)
return last_block

def _is_last_block_full(
self,
seq: Sequence,
) -> bool:
token_ids_len = len(seq.data.get_token_ids())
return token_ids_len > 0 and token_ids_len % seq.block_size == 0

def _is_last_block(
self,
seq: Sequence,
index: int,
) -> bool:
return index == len(seq.logical_token_blocks) - 1

def _maybe_promote_last_block(
self,
seq: Sequence,
last_block: PhysicalTokenBlock,
) -> PhysicalTokenBlock:
if self._is_last_block_full(seq):
return self._promote_last_block(seq, last_block)
else:
return last_block

def _allocate_last_physical_block(
self,
seq: Sequence,
) -> PhysicalTokenBlock:
block_hash: Optional[int] = None
if (self._is_last_block_full(seq)):
block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
num_hashed_tokens = seq.num_hashed_tokens_of_block(
len(seq.logical_token_blocks) - 1)
new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)
if block_hash is None:
assert new_block.ref_count == 1
return new_block

def append_slot(
self,
seq: Sequence,
) -> Optional[Tuple[int, int]]:
"""Allocate a physical slot for a new token."""
logical_blocks = seq.logical_token_blocks
block_table = self.block_tables[seq.seq_id]

# If we need to allocate a new physical block
if len(block_table) < len(logical_blocks):
# Currently this code only supports adding one physical block
assert len(block_table) == len(logical_blocks) - 1

if (self.block_sliding_window
and len(block_table) >= self.block_sliding_window):
# reuse a block
@@ -184,8 +287,8 @@ class BlockSpaceManager:
else:
# The sequence has a new logical block.
# Allocate a new physical block.
block = self.gpu_allocator.allocate()
block_table.append(block)
new_block = self._allocate_last_physical_block(seq)
block_table.append(new_block)
return None

# We want to append the token to the last physical block.
@@ -193,11 +296,15 @@ class BlockSpaceManager:
assert last_block.device == Device.GPU
if last_block.ref_count == 1:
# Not shared with other sequences. Appendable.
# If the last block is now complete, promote it to a full block so that it can be shared
new_block = self._maybe_promote_last_block(seq, last_block)
block_table[-1] = new_block
return None
else:
# The last block is shared with other sequences.
# Copy on Write: Allocate a new block and copy the tokens.
new_block = self.gpu_allocator.allocate()
new_block = self._allocate_last_physical_block(seq)

block_table[-1] = new_block
self.gpu_allocator.free(last_block)
return last_block.block_number, new_block.block_number
@@ -233,25 +340,18 @@ class BlockSpaceManager:

def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
# CPU block -> GPU block.
if seq_group.prefix is not None:
# make sure to swap in the prefix first
assert seq_group.prefix.allocated and seq_group.prefix.computed

mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = []
block_table = self.block_tables[seq.seq_id]
if seq_group.prefix is not None:
for block in seq_group.prefix.block_table:
new_block_table.append(block)
block.ref_count += 1

for cpu_block in block_table:
if cpu_block in mapping:
gpu_block = mapping[cpu_block]
gpu_block.ref_count += 1
else:
gpu_block = self.gpu_allocator.allocate()
gpu_block = self.gpu_allocator.allocate(
cpu_block.block_hash, cpu_block.num_hashed_tokens)
mapping[cpu_block] = gpu_block
new_block_table.append(gpu_block)
# Free the CPU block swapped in to GPU.
@@ -276,17 +376,12 @@ class BlockSpaceManager:
block_table = self.block_tables[seq.seq_id]

for gpu_block in block_table:
if (seq_group.prefix is not None
and gpu_block in seq_group.prefix.block_table):
# NOTE: We do not swap out the prefix blocks for now.
self.gpu_allocator.free(gpu_block)
continue

if gpu_block in mapping:
cpu_block = mapping[gpu_block]
cpu_block.ref_count += 1
else:
cpu_block = self.cpu_allocator.allocate()
cpu_block = self.cpu_allocator.allocate(
gpu_block.block_hash, gpu_block.num_hashed_tokens)
mapping[gpu_block] = cpu_block
new_block_table.append(cpu_block)
# Free the GPU block swapped out to CPU.
@@ -328,3 +423,49 @@ class BlockSpaceManager:

def get_num_free_cpu_blocks(self) -> int:
return self.cpu_allocator.get_num_free_blocks()

def access_all_blocks_in_seq(
self,
seq: Sequence,
access_time: float,
) -> None:
block_table = self.block_tables[seq.seq_id]
for block in block_table:
block.last_accessed = access_time

def compute_last_full_block_in_seq(self, seq: Sequence):
if seq.seq_id not in self.block_tables:
return
max_full_block = seq.get_len() // seq.block_size - 1
block_table = self.block_tables[seq.seq_id]
if max_full_block == -1:
return
block_table[max_full_block].computed = True

def get_all_block_ids_till_computed(self, seq: Sequence) -> List[int]:
if seq.seq_id not in self.block_tables:
return []
block_table = self.block_tables[seq.seq_id]
for block_idx in reversed(range(len(block_table))):
if block_table[block_idx].computed:
return [b.block_number for b in block_table[:block_idx + 1]]
return []

# Can return non-empty result only with prefix caching enabled.
def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
if not self.enable_caching:
return []

ids_list = [
self.get_all_block_ids_till_computed(seq)
for seq in iter(seq_group.seqs_dict.values())
]
return commonprefix([ids for ids in ids_list if ids != []])

# We only mark the last full block because with prefix caching,
# all blocks until the marked one are guaranteed to be computed.
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
if self.enable_caching:
for seq in seq_group.seqs_dict.values():
self.compute_last_full_block_in_seq(seq)
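A note on get_common_computed_block_ids above: it leans on os.path.commonprefix, which works element-wise on any sequences, not only strings. A tiny sketch with made-up block numbers shows the intent; only the leading run of computed blocks shared by every sequence in the group is safe to skip recomputing:

```python
# Sketch: commonprefix over per-sequence computed block ids.
from os.path import commonprefix

ids_list = [
    [0, 1, 2, 7],    # computed blocks of sequence A (hypothetical ids)
    [0, 1, 2, 9],    # computed blocks of sequence B
]
assert commonprefix(ids_list) == [0, 1, 2]
```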
vllm/core/evictor.py (new file, 161 lines)
@@ -0,0 +1,161 @@
import enum
from typing import Dict, List, Optional
from abc import ABC, abstractmethod, abstractproperty

from vllm.block import PhysicalTokenBlock


class EvictionPolicy(enum.Enum):
    """Enum for eviction policy used by make_evictor to instantiate the correct
    Evictor subclass.
    """
    LRU = enum.auto()
    FIFO = enum.auto()


class Evictor(ABC):
    """The Evictor subclasses should be used by the BlockAllocator class to
    handle eviction of freed PhysicalTokenBlocks.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __contains__(self, block_hash: int) -> bool:
        pass

    @abstractmethod
    def evict(self) -> PhysicalTokenBlock:
        """Runs the eviction algorithm and returns the evicted block"""
        pass

    @abstractmethod
    def add(self, block: PhysicalTokenBlock):
        """Adds block to the evictor, making it a candidate for eviction"""
        pass

    @abstractmethod
    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Simply removes the block with the hash value block_hash from the
        evictor. The caller is responsible for making sure that block_hash is
        contained in the evictor before calling remove. Should be used to
        "bring back" blocks that have been freed but not evicted yet.
        """
        pass

    @abstractproperty
    def num_blocks(self) -> int:
        pass


class LRUEvictor(Evictor):
    """Evicts in a least-recently-used order using the last_accessed timestamp
    that's recorded in the PhysicalTokenBlock. If there are multiple blocks
    with the same last_accessed time, then the one with the largest
    num_hashed_tokens will be evicted. If two blocks each have the lowest
    last_accessed time and the highest num_hashed_tokens value, then one will
    be chosen arbitrarily.
    """

    def __init__(self):
        self.free_table: Dict[int, PhysicalTokenBlock] = {}

    def __contains__(self, block_hash: int) -> bool:
        return block_hash in self.free_table

    # TODO: The performance of this evict function can be optimized further.
    def evict(self) -> PhysicalTokenBlock:
        free_blocks: List[PhysicalTokenBlock] = list(self.free_table.values())
        if len(free_blocks) == 0:
            raise ValueError("No usable cache memory left")

        # Find the lowest timestamp.
        lowest_timestamp = free_blocks[0].last_accessed
        for block in free_blocks:
            if block.last_accessed < lowest_timestamp:
                lowest_timestamp = block.last_accessed

        # Find all blocks with the lowest timestamp.
        least_recent: List[PhysicalTokenBlock] = []
        for block in free_blocks:
            if block.last_accessed == lowest_timestamp:
                least_recent.append(block)

        # Find the highest num_hashed_tokens among those blocks.
        highest_num_hashed_tokens = 0
        for block in least_recent:
            if block.num_hashed_tokens > highest_num_hashed_tokens:
                highest_num_hashed_tokens = block.num_hashed_tokens

        evicted_block: Optional[PhysicalTokenBlock] = None

        # Among the least recently used blocks, evict the first one with the
        # highest num_hashed_tokens.
        for block in least_recent:
            if block.num_hashed_tokens == highest_num_hashed_tokens:
                evicted_block = block
                break

        assert evicted_block is not None

        del self.free_table[evicted_block.block_hash]

        evicted_block.computed = False
        return evicted_block

    def add(self, block: PhysicalTokenBlock):
        self.free_table[block.block_hash] = block

    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        if block_hash not in self.free_table:
            raise ValueError(
                "Attempting to remove block that's not in the evictor")
        block: PhysicalTokenBlock = self.free_table[block_hash]
        del self.free_table[block_hash]
        return block

    @property
    def num_blocks(self) -> int:
        return len(self.free_table)


class RandomEvictor(Evictor):
    """Evicts in a first-in-first-out order"""

    def __init__(self):
        self.free_table: Dict[int, PhysicalTokenBlock] = {}

    def __contains__(self, block_hash: int) -> bool:
        return block_hash in self.free_table

    def evict(self) -> PhysicalTokenBlock:
        if len(self.free_table) == 0:
            raise ValueError("No usable cache memory left")
        evicted_block = next(iter(self.free_table.values()))
        evicted_block.computed = False
        del self.free_table[evicted_block.block_hash]
        return evicted_block

    def add(self, block: PhysicalTokenBlock):
        self.free_table[block.block_hash] = block

    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        if block_hash not in self.free_table:
            raise ValueError(
                "Attempting to remove block that's not in the evictor")
        block: PhysicalTokenBlock = self.free_table[block_hash]
        del self.free_table[block_hash]
        return block

    @property
    def num_blocks(self) -> int:
        return len(self.free_table)


def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
    if eviction_policy == EvictionPolicy.LRU:
        return LRUEvictor()
    elif eviction_policy == EvictionPolicy.FIFO:
        return RandomEvictor()
    else:
        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
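To illustrate the LRU policy and its tie-break, here is a hedged sketch exercising make_evictor with the PhysicalTokenBlock constructor added above; the block numbers, hashes, and timestamps are invented for the example:

```python
# Sketch: eviction order of the LRUEvictor.
from vllm.block import PhysicalTokenBlock
from vllm.core.evictor import EvictionPolicy, make_evictor
from vllm.utils import Device

def make_block(number, num_hashed_tokens, last_accessed):
    block = PhysicalTokenBlock(device=Device.CPU, block_number=number,
                               block_size=16, block_hash=number,
                               num_hashed_tokens=num_hashed_tokens)
    block.last_accessed = last_accessed
    return block

evictor = make_evictor(EvictionPolicy.LRU)
evictor.add(make_block(0, num_hashed_tokens=16, last_accessed=1.0))
evictor.add(make_block(1, num_hashed_tokens=48, last_accessed=1.0))
evictor.add(make_block(2, num_hashed_tokens=32, last_accessed=2.0))

# Oldest timestamp goes first; among ties, the block covering more hashed
# tokens is evicted first.
assert evictor.evict().block_number == 1
assert evictor.evict().block_number == 0
assert evictor.evict().block_number == 2
```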
@@ -10,7 +10,6 @@ from vllm.lora.request import LoRARequest
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceStatus)
from vllm.prefix import PrefixPool

logger = init_logger(__name__)

@@ -95,10 +94,8 @@ class Scheduler:
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
sliding_window=self.cache_config.sliding_window)

# Create the prefix pool to cache the prefixes.
self.prefix_pool = PrefixPool(self.cache_config.block_size)
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)

# Sequence groups in the WAITING state.
self.waiting: Deque[SequenceGroup] = deque()
@@ -374,10 +371,12 @@ class Scheduler:

seq_data: Dict[int, SequenceData] = {}
block_tables: Dict[int, List[int]] = {}

for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
self.block_manager.access_all_blocks_in_seq(seq, now)

seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
@@ -386,7 +385,8 @@ class Scheduler:
sampling_params=seq_group.sampling_params,
block_tables=block_tables,
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
computed_block_nums=self.block_manager.
get_common_computed_block_ids(seq_group),
state=seq_group.state,
)
seq_group_metadata_list.append(seq_group_metadata)
@@ -496,3 +496,6 @@ class Scheduler:
blocks_to_swap_out.update(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
self.block_manager.mark_blocks_as_computed(seq_group)
@@ -25,6 +25,7 @@ class EngineArgs:
tensor_parallel_size: int = 1
max_parallel_loading_workers: Optional[int] = None
block_size: int = 16
enable_prefix_caching: bool = False
swap_space: int = 4  # GiB
gpu_memory_utilization: float = 0.90
max_num_batched_tokens: Optional[int] = None
@@ -173,6 +174,11 @@ class EngineArgs:
default=EngineArgs.block_size,
choices=[8, 16, 32, 128],
help='token block size')

parser.add_argument('--enable-prefix-caching',
action='store_true',
help='Enables automatic prefix caching')

parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
@@ -293,7 +299,8 @@ class EngineArgs:
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space, self.kv_cache_dtype,
model_config.get_sliding_window())
model_config.get_sliding_window(),
self.enable_prefix_caching)
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,
@@ -225,7 +225,6 @@ class _AsyncLLMEngine(LLMEngine):
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -245,7 +244,6 @@ class _AsyncLLMEngine(LLMEngine):
sampling_params=sampling_params,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)

async def _run_workers_async(
@@ -422,7 +420,6 @@ class AsyncLLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@@ -435,7 +432,6 @@ class AsyncLLMEngine:
max_log_len]
logger.info(f"Received request {request_id}: "
f"prompt: {shortened_prompt!r}, "
f"prefix_pos: {prefix_pos},"
f"sampling_params: {sampling_params}, "
f"prompt_token_ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
@@ -472,8 +468,7 @@ class AsyncLLMEngine:
sampling_params=sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
lora_request=lora_request)

return stream

@@ -484,7 +479,6 @@ class AsyncLLMEngine:
request_id: str,
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.

@@ -500,11 +494,6 @@ class AsyncLLMEngine:
prompt_token_ids: The token IDs of the prompt. If None, we
use the tokenizer to convert the prompts to token IDs.
lora_request: LoRA request to use for generation, if any.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.

Yields:
The output `RequestOutput` objects from the LLMEngine for the
@@ -565,7 +554,6 @@ class AsyncLLMEngine:
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)

async for request_output in stream:
@@ -415,7 +415,6 @@ class LLMEngine:
prompt_token_ids: Optional[List[int]] = None,
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
"""Add a request to the engine's request pool.

@@ -432,11 +431,6 @@ class LLMEngine:
use the tokenizer to convert the prompts to token IDs.
arrival_time: The arrival time of the request. If None, we use
the current monotonic time.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.

Details:
- Set arrival_time to the current time if it is None.
@@ -479,18 +473,13 @@ class LLMEngine:
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
lora_request)

# Check whether the input specifies prefix
prefix = self.scheduler.prefix_pool.add_or_get_prefix(
prompt_token_ids[:prefix_pos], lora_request.lora_int_id
if lora_request else 0) if prefix_pos is not None else None

# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
arrival_time, lora_request, prefix)
arrival_time, lora_request)

# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
@@ -752,6 +741,13 @@ class LLMEngine:
now = time.time()
# Update the scheduled sequence groups with the model outputs.
scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups

# If prefix caching is enabled, mark all blocks in the sequence groups
# as completed so that future requests don't attempt to recompute them
if self.cache_config.enable_prefix_caching:
for seq_group in scheduled_seq_groups:
self.scheduler.mark_blocks_as_computed(seq_group)

for seq_group, outputs in zip(scheduled_seq_groups, output):
self._process_sequence_group_outputs(seq_group, outputs)

@@ -768,12 +764,6 @@ class LLMEngine:
request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)

# Update prefix state, now all the uncomputed prefixes are computed.
for seq_group in scheduled_seq_groups:
if (seq_group.prefix is not None and seq_group.prefix.allocated
and not seq_group.prefix.computed):
seq_group.prefix.computed = True

# Log stats.
if self.log_stats:
self.stat_logger.log(self._get_stats(scheduler_outputs))
@@ -39,15 +39,11 @@ async def generate(request: Request) -> Response:
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()

results_generator = engine.generate(prompt,
sampling_params,
request_id,
prefix_pos=prefix_pos)
results_generator = engine.generate(prompt, sampling_params, request_id)

# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
@@ -124,7 +124,6 @@ class LLM:
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
@@ -140,11 +139,6 @@ class LLM:
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.

@@ -171,14 +165,12 @@ class LLM:
prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos_i)
lora_request=lora_request)
return self._run_engine(use_tqdm)

def _add_request(
@@ -187,15 +179,13 @@ class LLM:
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos)
lora_request=lora_request)

def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
@@ -1,87 +0,0 @@
from typing import Dict, List, Sequence, Tuple, Optional

from vllm.block import BlockTable


class Prefix:
"""Data and states associated with a prefix of prompt tokens for multiple
sequence groups.

NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.

Args:
token_ids: The token ids of the prefix.
block_size: The block size of the executed model.
"""

def __init__(
self,
token_ids: Sequence[int],
block_size: int,
) -> None:
self.token_ids = tuple(token_ids)
self.block_size = block_size
self.length = len(token_ids)
self.hash = hash(token_ids)
assert self.length % block_size == 0
self.block_table: Optional[BlockTable] = None
self.computed = False

@property
def allocated(self) -> bool:
return self.block_table is not None

def get_num_blocks(self) -> int:
return self.length // self.block_size

def get_block_numbers(self) -> List[int]:
return [block.block_number for block in self.block_table]

def get_length(self) -> int:
return self.length

def __hash__(self) -> int:
return self.hash

def set_block_table(self, block_table: BlockTable) -> None:
self.block_table = block_table.copy()


class PrefixPool:
"""Manages all the prompt prefixes.

NOTE: This feature is experimental and may be replaced with automatic
prefix caching in the future.

Args:
block_size: The block size of the executed model.

Attributes:
prefixes: A list of all the prefixes.
block_size: The block size of the executed model.
"""

def __init__(
self,
block_size: int,
) -> None:
# TODO(zhuohan): Add a capacity limit to the prefix pool.
self.prefixes: Dict[int, Prefix] = {}
self.block_size = block_size

def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]:
new_length = len(token_ids) // self.block_size * self.block_size
return tuple(token_ids[:new_length])

def add_or_get_prefix(self, token_ids: Sequence[int],
lora_int_id: int) -> Optional[Prefix]:
token_ids = self._truncate_token_ids(token_ids)
if len(token_ids) == 0:
# Prefix is empty.
return None
prefix = Prefix(token_ids, self.block_size)
prefix_hash = hash((prefix, lora_int_id))
if prefix_hash not in self.prefixes:
self.prefixes[prefix_hash] = prefix
return self.prefixes[prefix_hash]
@@ -5,7 +5,6 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from vllm.block import LogicalTokenBlock
from vllm.prefix import Prefix
from vllm.sampling_params import SamplingParams
from vllm.lora.request import LoRARequest

@@ -161,6 +160,16 @@ class Sequence:
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0

# TODO The current hashing function is O(L^2). We should optimize this in
# the future.
def hash_of_block(self, logical_idx: int) -> int:
# Compute the number of tokens in the sequence
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
return hash(tuple(self.data.get_token_ids()[0:num_tokens]))

def num_hashed_tokens_of_block(self, logical_idx: int):
return logical_idx * self.block_size + self.block_size

def _append_logical_block(self) -> None:
block = LogicalTokenBlock(
block_number=len(self.logical_token_blocks),
@@ -265,7 +274,6 @@ class SequenceGroup:
sampling_params: The sampling parameters used to generate the outputs.
arrival_time: The arrival time of the request.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""

def __init__(
@@ -275,7 +283,6 @@ class SequenceGroup:
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
) -> None:
self.request_id = request_id
self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -286,7 +293,6 @@ class SequenceGroup:
first_token_time=None,
time_in_queue=None)
self.lora_request = lora_request
self.prefix: Optional[Prefix] = prefix
self.prompt_logprobs: Optional[PromptLogprobs] = None
self.state = SequenceGroupState()

@@ -302,6 +308,10 @@ class SequenceGroup:
# We use the prompt of an arbitrary sequence.
return next(iter(self.seqs_dict.values())).data.prompt_token_ids

@property
def block_size(self) -> int:
return next(iter(self.seqs_dict.values())).block_size

@property
def lora_int_id(self) -> int:
return self.lora_request.lora_int_id if self.lora_request else 0
@@ -408,7 +418,6 @@ class SequenceGroupMetadata:
numbers)
state: Internal state tied to this sequence group.
lora_request: LoRA request.
prefix: The prefix of the prompt of the sequence group.
"""

def __init__(
@@ -419,7 +428,7 @@ class SequenceGroupMetadata:
sampling_params: SamplingParams,
block_tables: Dict[int, List[int]],
lora_request: Optional[LoRARequest] = None,
prefix: Optional[Prefix] = None,
computed_block_nums: Optional[List[int]] = None,
state: Optional[SequenceGroupState] = None,
) -> None:
self.request_id = request_id
@@ -428,7 +437,7 @@ class SequenceGroupMetadata:
self.sampling_params = sampling_params
self.block_tables = block_tables
self.lora_request = lora_request
self.prefix = prefix
self.computed_block_nums = computed_block_nums
self.state = SequenceGroupState() if state is None else state

@property
@@ -145,33 +145,37 @@ class ModelRunner:
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
prefix_len = 0
prefix = seq_group_metadata.prefix
if prefix is not None and prefix.computed:
prefix_len = prefix.get_length()
prompt_tokens = prompt_tokens[prefix_len:]
prefix_block_tables.append(prefix.get_block_numbers())
computed_len = 0

# NOTE: This only works for oooooooxxx style attention.
computed_block_nums = seq_group_metadata.computed_block_nums
if computed_block_nums is not None and len(
computed_block_nums) > 0 and self.sliding_window is None:
# Prefix is not supported with sliding_window
computed_len = len(computed_block_nums) * self.block_size
prompt_tokens = prompt_tokens[computed_len:]
prefix_block_tables.append(computed_block_nums)
else:
prefix_block_tables.append([])
# actual prompt lens
context_lens.append(prefix_len)
subquery_lens.append(prompt_len - prefix_len)
context_lens.append(computed_len)
subquery_lens.append(prompt_len - computed_len)

input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(
list(range(prefix_len, prefix_len + len(prompt_tokens))))
list(range(computed_len, computed_len + len(prompt_tokens))))

lora_id = seq_group_metadata.lora_int_id

if lora_id > 0:
lora_requests.add(seq_group_metadata.lora_request)

lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
lora_index_mapping.append([lora_id] * (prompt_len - computed_len))
lora_prompt_mapping.extend(
[lora_id] *
(prompt_len - prefix_len
(prompt_len - computed_len
if seq_group_metadata.sampling_params.prompt_logprobs else 1))

if seq_group_metadata.block_tables is None:
@@ -190,11 +194,11 @@ class ModelRunner:
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
assert prefix_len == 0, (
assert computed_len == 0, (
"Prefix caching is currently not supported with "
"sliding window attention")
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prefix_len, prompt_len):
for i in range(computed_len, prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
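The arithmetic that replaces prefix_len is simple enough to show standalone: the prompt is advanced past every block the scheduler reports as already computed, and only the remainder gets a forward pass. Block size and ids below are made up:

```python
# Sketch: how computed_block_nums shortens the prompt that still needs a forward pass.
block_size = 16
prompt_tokens = list(range(40))        # 40 made-up prompt token ids
computed_block_nums = [3, 7]           # two blocks already cached (ids are arbitrary)

computed_len = len(computed_block_nums) * block_size   # 32 tokens already have KV cache
remaining_tokens = prompt_tokens[computed_len:]        # only 8 tokens are recomputed
input_positions = list(range(computed_len, computed_len + len(remaining_tokens)))

assert len(remaining_tokens) == 8
assert input_positions[0] == 32
```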