Possible fix for conflict between Automated Prefix Caching (#2762) and multi-LoRA support (#1804) (#3263)
This commit is contained in:
parent
385da2dae2
commit
8cbba4622c
@ -2,8 +2,11 @@
|
|||||||
|
|
||||||
Run `pytest tests/test_cache_block_hashing.py`.
|
Run `pytest tests/test_cache_block_hashing.py`.
|
||||||
"""
|
"""
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.transformers_utils.tokenizer import TokenizerGroup
|
from vllm.transformers_utils.tokenizer import TokenizerGroup
|
||||||
from vllm.sequence import Sequence
|
from vllm.sequence import Sequence
|
||||||
|
|
||||||
@ -36,7 +39,10 @@ def flatten_2d(li):
|
|||||||
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
|
||||||
@pytest.mark.parametrize("block_size", [16])
|
@pytest.mark.parametrize("block_size", [16])
|
||||||
@pytest.mark.parametrize("max_num_seqs", [256])
|
@pytest.mark.parametrize("max_num_seqs", [256])
|
||||||
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
|
@pytest.mark.parametrize("concurrent_lora_int_ids",
|
||||||
|
[[None], [1], [None, 1], [None, 1, 2], [1, 2]])
|
||||||
|
def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
|
||||||
|
concurrent_lora_int_ids: List[Optional[int]]):
|
||||||
|
|
||||||
tokenizer = TokenizerGroup(
|
tokenizer = TokenizerGroup(
|
||||||
tokenizer_id="facebook/opt-125m",
|
tokenizer_id="facebook/opt-125m",
|
||||||
@ -48,20 +54,30 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int):
|
|||||||
hashes = []
|
hashes = []
|
||||||
|
|
||||||
for prefix in prefixes:
|
for prefix in prefixes:
|
||||||
hashes.append([])
|
for lora_int_id in concurrent_lora_int_ids:
|
||||||
prompts = [prefix + prompt for prompt in sample_prompts]
|
lora_request = None
|
||||||
seq_id = 0
|
|
||||||
for prompt in prompts:
|
|
||||||
hashes[-1].append([])
|
|
||||||
prompt_token_ids = tokenizer.encode(prompt)
|
|
||||||
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
|
||||||
tokenizer.tokenizer.eos_token_id)
|
|
||||||
|
|
||||||
num_blocks = len(prompt_token_ids) // block_size
|
if lora_int_id is not None:
|
||||||
for idx in range(num_blocks):
|
lora_request = LoRARequest(
|
||||||
hashes[-1][-1].append(seq.hash_of_block(idx))
|
f"example_lora_{lora_int_id}",
|
||||||
|
lora_int_id,
|
||||||
|
f"example/path/to/lora_{lora_int_id}",
|
||||||
|
)
|
||||||
|
|
||||||
seq_id += 1
|
hashes.append([])
|
||||||
|
prompts = [prefix + prompt for prompt in sample_prompts]
|
||||||
|
seq_id = 0
|
||||||
|
for prompt in prompts:
|
||||||
|
hashes[-1].append([])
|
||||||
|
prompt_token_ids = tokenizer.encode(prompt)
|
||||||
|
seq = Sequence(seq_id, prompt, prompt_token_ids, block_size,
|
||||||
|
tokenizer.tokenizer.eos_token_id, lora_request)
|
||||||
|
|
||||||
|
num_blocks = len(prompt_token_ids) // block_size
|
||||||
|
for idx in range(num_blocks):
|
||||||
|
hashes[-1][-1].append(seq.hash_of_block(idx))
|
||||||
|
|
||||||
|
seq_id += 1
|
||||||
|
|
||||||
# Check that hashes made with two prefixes with different first blocks are
|
# Check that hashes made with two prefixes with different first blocks are
|
||||||
# different everywhere.
|
# different everywhere.
|
||||||
|
@ -175,7 +175,8 @@ class Sequence:
|
|||||||
# TODO: The current hashing function is O(L^2). We should optimize
|
# TODO: The current hashing function is O(L^2). We should optimize
|
||||||
# this in the future.
|
# this in the future.
|
||||||
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
|
num_tokens = self.num_hashed_tokens_of_block(logical_idx)
|
||||||
return hash(tuple(self.data.get_token_ids()[0:num_tokens]))
|
return hash(
|
||||||
|
(tuple(self.data.get_token_ids()[0:num_tokens]), self.lora_int_id))
|
||||||
|
|
||||||
def num_hashed_tokens_of_block(self, logical_idx: int):
|
def num_hashed_tokens_of_block(self, logical_idx: int):
|
||||||
return logical_idx * self.block_size + self.block_size
|
return logical_idx * self.block_size + self.block_size
|
||||||
|
Loading…
x
Reference in New Issue
Block a user