[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)

youkaichao 2024-05-08 12:07:05 -07:00 committed by GitHub
parent ad932a221d
commit 20cfcdec99
21 changed files with 137 additions and 109 deletions

View File

@ -8,12 +8,12 @@
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping);
const torch::Tensor& block_mapping);
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
torch::Tensor& block_mapping);
const torch::Tensor& block_mapping);
void reshape_and_cache(
torch::Tensor& key,

View File

@ -23,7 +23,7 @@
void swap_blocks(
torch::Tensor& src,
torch::Tensor& dst,
const std::map<int64_t, int64_t>& block_mapping) {
const torch::Tensor& block_mapping) {
torch::Device src_device = src.device();
torch::Device dst_device = dst.device();
cudaMemcpyKind memcpy_type;
@ -40,6 +40,11 @@ void swap_blocks(
TORCH_CHECK(false, "Invalid device combination");
}
// NOTE(youkaichao): keep in mind that `block_mapping` should be
// a cpu tensor, otherwise every `item` call will require a gpu-cpu
// synchronization.
TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
char *src_ptr = static_cast<char*>(src.data_ptr());
char *dst_ptr = static_cast<char*>(dst.data_ptr());
@ -47,9 +52,10 @@ void swap_blocks(
const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// NOTE(woosuk): This can be slow if the number of blocks is large.
for (const auto& pair : block_mapping) {
int64_t src_block_number = pair.first;
int64_t dst_block_number = pair.second;
const int64_t num_blocks = block_mapping.size(0);
for (size_t i = 0; i < num_blocks; i++) {
int64_t src_block_number = block_mapping[i][0].item<int64_t>();
int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
int64_t src_offset = src_block_number * block_size_in_bytes;
int64_t dst_offset = dst_block_number * block_size_in_bytes;
cudaMemcpyAsync(
@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel(
void copy_blocks(
std::vector<torch::Tensor>& key_caches,
std::vector<torch::Tensor>& value_caches,
torch::Tensor& block_mapping) {
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
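
The new swap_blocks signature takes block_mapping as an int64 tensor of shape [num_pairs, 2] that must live on the CPU: each row is read back with .item(), so a GPU-resident tensor would force a device synchronization per element. A minimal caller-side sketch (block numbers are made up; ops refers to the same binding exercised in the kernel test further down in this diff):

import torch

# (src_block, dst_block) pairs; values are illustrative only
pairs = [(3, 72), (56, 35), (84, 34)]
block_mapping = torch.tensor(pairs, dtype=torch.int64, device="cpu").view(-1, 2)
# ops.swap_blocks(src_key_cache, dst_key_cache, block_mapping)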

View File

@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl(
void copy_blocks(std::vector<torch::Tensor> &key_caches,
std::vector<torch::Tensor> &value_caches,
torch::Tensor& block_mapping) {
const torch::Tensor& block_mapping) {
int num_layers = key_caches.size();
TORCH_CHECK(num_layers == value_caches.size());
if (num_layers == 0) {
@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
}
void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
const std::map<int64_t, int64_t> &block_mapping) {
const torch::Tensor& block_mapping) {
TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
}

View File

@ -219,7 +219,7 @@ def test_swap():
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_out(seq_group)
assert list(mapping.keys()) == gpu_blocks
assert [x[0] for x in mapping] == gpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
@ -232,7 +232,7 @@ def test_swap():
before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
mapping = block_manager.swap_in(seq_group)
assert list(mapping.keys()) == cpu_blocks
assert [x[0] for x in mapping] == cpu_blocks
after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
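
swap_out and swap_in now return a list of (src, dst) block-number pairs rather than a dict, so the assertions read the first element of each pair instead of dict keys. A toy illustration of the new shape (block numbers are hypothetical):

# swap_out: list of (gpu_block, cpu_block) pairs
mapping = [(0, 8), (1, 9), (2, 10)]
assert [src for src, _ in mapping] == [0, 1, 2]   # replaces list(mapping.keys())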

View File

@ -355,8 +355,8 @@ def test_swap():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
# Add 1 more task. Swap should be prioritized over new prefill.
_, seq_group = create_dummy_prompt("2", prompt_length=60)
@ -365,8 +365,8 @@ def test_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def test_running_prefill_prioritized_over_swap():
@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
# Add 1 more task. Swap is not possible, so prefill is running.
scheduler.block_manager.can_swap_in = MagicMock()
@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert out.scheduled_seq_groups[0].seq_group == seq_group2
# Now although swap is possible, running prefill is prioritized.
@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert not seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2
append_new_token(seq_group2, 1)
@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
assert len(out.scheduled_seq_groups) == 1
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 1
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == []
assert out.blocks_to_swap_out == []
assert not seq_group2.is_prefill()
assert out.scheduled_seq_groups[0].seq_group == seq_group2
append_new_token(seq_group2, 1)
@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 1
assert out.num_batched_tokens == 30
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def test_chunked_prefill_preempt():
@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 0
assert out.num_batched_tokens == 0
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out == []
assert out.blocks_to_swap_in == []
# Make sure we can reschedule preempted request.
_, out = schedule_and_update_computed_tokens(scheduler)

View File

@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) == 2
assert out.num_batched_tokens == 2
assert out.blocks_to_swap_out != {}
assert out.blocks_to_swap_in == {}
assert out.blocks_to_swap_out != []
assert out.blocks_to_swap_in == []
append_new_token(out, 1)
# Add 1 more task. Swap should be prioritized over prefill.
@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
assert len(out.scheduled_seq_groups) == 3
# 3 decodes. It is swapped in.
assert out.num_batched_tokens == 3
assert out.blocks_to_swap_in != {}
assert out.blocks_to_swap_out == {}
assert out.blocks_to_swap_in != []
assert out.blocks_to_swap_out == []
def initialize_scheduler(*,
@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
# NOTE: When enable_chunk is False, num_seqs budget is not updated.
# assert budget.num_curr_seqs == 1
# Both should be preempted, not swapped.
assert output.blocks_to_swap_out == {}
assert output.blocks_to_swap_out == []
# Nothing is copied.
assert output.blocks_to_copy == []
@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
scheduler.block_manager.can_append_slots.side_effect = (
cannot_append_second_group)
scheduler.block_manager.swap_out = MagicMock()
expected_swap_mapping = {"5": "7"}
expected_swap_mapping = [("5", "7")]
scheduler.block_manager.swap_out.return_value = expected_swap_mapping
remainig_running, output = scheduler._schedule_running(
@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
assert len(output.preempted) == 0
assert len(output.swapped_out) == 0
# Nothing is preempted.
assert output.blocks_to_swap_out == {}
assert output.blocks_to_swap_out == []
# Since append_slot returns the source -> dist mapping, it should
# applied.
assert output.blocks_to_copy == [(2, 3)]
@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
assert len(output.decode_seq_groups) == 1
assert len(output.prefill_seq_groups) == 0
# swap in is the reverse of swap out
blocks_to_swap_in_reverse = {}
for swapin, swapout in output.blocks_to_swap_in.items():
blocks_to_swap_in_reverse[swapout] = swapin
blocks_to_swap_in_reverse = []
for swapin, swapout in output.blocks_to_swap_in:
blocks_to_swap_in_reverse.append((swapout, swapin))
assert blocks_to_swap_out == blocks_to_swap_in_reverse
@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for i in range(4):
_, seq_group = create_dummy_prompt(str(i), prompt_length=60)
scheduler._allocate_and_set_running(seq_group)
@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = set()
blocks_to_swap_out = {}
blocks_to_swap_out = []
for i in range(2):
_, seq_group = create_dummy_prompt(str(i),
prompt_length=60,
@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@ -808,7 +808,7 @@ def test_infeasible_swap():
swapped = deque()
policy = PolicyFactory.get_policy(policy_name="fcfs")
curr_loras = None
blocks_to_swap_out = {}
blocks_to_swap_out = []
for _ in range(2):
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
_, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
scheduler._allocate_and_set_running(seq_group)
append_new_token_seq_group(60, seq_group, 1)
blocks_to_swap_out = {}
blocks_to_swap_out = []
scheduler._swap_out(seq_group, blocks_to_swap_out)
swapped.append(seq_group)

View File

@ -315,7 +315,10 @@ def test_swap_blocks(
else:
dst_blocks = random.sample(range(num_blocks), num_mappings)
block_mapping = dict(zip(src_blocks, dst_blocks))
block_mapping = list(zip(src_blocks, dst_blocks))
block_mapping_tensor = torch.tensor(block_mapping,
dtype=torch.int64,
device="cpu").view(-1, 2)
# Create the KV caches on the first device.
src_key_caches, src_value_caches = kv_cache_factory(
@ -331,10 +334,12 @@ def test_swap_blocks(
src_value_caches_clone = src_value_caches[0].clone()
# Call the swap_blocks kernel.
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
block_mapping_tensor)
ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
block_mapping_tensor)
for src, dst in block_mapping.items():
for src, dst in block_mapping:
assert torch.allclose(src_key_caches_clone[src].cpu(),
dist_key_caches[0][dst].cpu())
assert torch.allclose(src_value_caches_clone[src].cpu(),

View File

@ -54,10 +54,10 @@ def test_swap() -> None:
a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
# Test swap out.
blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=[],
blocks_to_swap_in={},
blocks_to_swap_in=[],
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=[],
)
@ -66,24 +66,24 @@ def test_swap() -> None:
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in blocks_to_swap_out.items():
for src, dst in blocks_to_swap_out:
assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
# Test swap in.
execute_model_req.blocks_to_swap_out = {}
execute_model_req.blocks_to_swap_in = {
19: 45,
67: 23,
12: 78,
40: 99,
1: 71
}
execute_model_req.blocks_to_swap_out = []
execute_model_req.blocks_to_swap_in = [
(19, 45),
(67, 23),
(12, 78),
(40, 99),
(1, 71),
]
worker.execute_model(execute_model_req=execute_model_req)
for i in range(num_layers):
gpu_key_cache, gpu_value_cache = gpu_cache[i]
cpu_key_cache, cpu_value_cache = cpu_cache[i]
for src, dst in execute_model_req.blocks_to_swap_in.items():
for src, dst in execute_model_req.blocks_to_swap_in:
assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
assert allclose(gpu_value_cache[dst], cpu_value_cache[src])

View File

@ -39,7 +39,7 @@ class AttentionBackend(ABC):
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_to_dst: torch.Tensor,
) -> None:
raise NotImplementedError

View File

@ -5,7 +5,7 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or
flashinfer for all the attention operations.
"""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
from flash_attn import flash_attn_varlen_func
@ -45,7 +45,7 @@ class FlashAttentionBackend(AttentionBackend):
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

View File

@ -39,7 +39,7 @@ class FlashInferBackend(AttentionBackend):
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_to_dst: torch.Tensor,
) -> None:
raise NotImplementedError

View File

@ -1,6 +1,6 @@
"""Attention layer ROCm GPUs."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
@ -43,7 +43,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

View File

@ -1,7 +1,7 @@
""" Attention layer with torch scaled_dot_product_attention
and PagedAttention."""
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type
from typing import List, Optional, Tuple, Type
import torch
from torch.nn.functional import scaled_dot_product_attention
@ -41,7 +41,7 @@ class TorchSDPABackend(AttentionBackend):
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch
@ -196,7 +196,7 @@ class PagedAttention:
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: Dict[int, int],
src_to_dst: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
dst_key_cache = dst_kv_cache[0]

View File

@ -473,11 +473,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
def swap_in(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> Dict[int, int]:
num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
assert (num_lookahead_slots == 0
), "BlockSpaceManagerV1 does not support lookahead allocation"
# CPU block -> GPU block.
# dict is efficient in lookup `if cpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
new_block_table: BlockTable = []
@ -500,14 +501,16 @@ class BlockSpaceManagerV1(BlockSpaceManager):
cpu_block.block_number: gpu_block.block_number
for cpu_block, gpu_block in mapping.items()
}
return block_number_mapping
# convert to list of tuples once here
return list(block_number_mapping.items())
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
blocks = self._get_physical_blocks(seq_group)
return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
# GPU block -> CPU block.
# dict is efficient in lookup `if gpu_block in mapping`
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_block_table: BlockTable = []
@ -530,7 +533,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
gpu_block.block_number: cpu_block.block_number
for gpu_block, cpu_block in mapping.items()
}
return block_number_mapping
# convert to list of tuples once here
return list(block_number_mapping.items())
def _free_block_table(self, block_table: BlockTable) -> None:
# when using a sliding window, each seq will only use up
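
The block manager still builds a dict keyed by the source PhysicalTokenBlock, because the `if block in mapping` membership check stays O(1); only the return value is flattened into (src, dst) block numbers, once, at the boundary. A condensed sketch of that pattern (names and control flow simplified, not the actual implementation):

def _swap(blocks_to_move, allocator):         # simplified stand-in, not vLLM's API
    mapping = {}                              # src PhysicalTokenBlock -> dst PhysicalTokenBlock
    for block in blocks_to_move:
        if block in mapping:                  # O(1) membership test is why a dict is kept
            mapping[block].ref_count += 1
        else:
            mapping[block] = allocator.allocate()
    # converted to a list of (src, dst) block numbers exactly once, at the boundary
    return [(src.block_number, dst.block_number) for src, dst in mapping.items()]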

View File

@ -243,13 +243,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
return AllocStatus.LATER
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
num_lookahead_slots: int) -> List[Tuple[int, int]]:
raise NotImplementedError
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
return False
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
raise NotImplementedError
def get_num_free_gpu_blocks(self) -> int:

View File

@ -1,6 +1,6 @@
import enum
from abc import ABC, abstractmethod
from typing import Dict, List
from typing import List
from typing import Sequence as GenericSequence
from typing import Tuple
@ -69,7 +69,7 @@ class BlockSpaceManager(ABC):
@abstractmethod
def swap_in(self, seq_group: SequenceGroup,
num_lookahead_slots: int) -> Dict[int, int]:
num_lookahead_slots: int) -> List[Tuple[int, int]]:
pass
@abstractmethod
@ -77,7 +77,7 @@ class BlockSpaceManager(ABC):
pass
@abstractmethod
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
pass
@abstractmethod

View File

@ -117,10 +117,10 @@ class SchedulerOutputs:
num_prefill_groups: int
# Total number of batched tokens.
num_batched_tokens: int
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in: Dict[int, int]
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int]
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, int]]
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, int]]
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]]
# Sequence groups that are going to be ignored.
@ -174,7 +174,7 @@ class SchedulerRunningOutputs:
# Sequences that are swapped out.
swapped_out: List[SequenceGroup]
# The blocks to swap out.
blocks_to_swap_out: Dict[int, int]
blocks_to_swap_out: List[Tuple[int, int]]
# The blocks to copy.
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
@ -187,7 +187,7 @@ class SchedulerRunningOutputs:
prefill_seq_groups=[],
preempted=[],
swapped_out=[],
blocks_to_swap_out={},
blocks_to_swap_out=[],
blocks_to_copy=[],
num_lookahead_slots=0,
)
@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs:
# phase. I.e., it means the prefill has been chunked.
prefill_seq_groups: List[SequenceGroup]
# The blocks to swap in.
blocks_to_swap_in: Dict[int, int]
blocks_to_swap_in: List[Tuple[int, int]]
# The blocks to copy.
blocks_to_copy: List[Tuple[int, int]]
# The number of slots for lookahead decoding.
@ -219,7 +219,7 @@ class SchedulerSwappedInOutputs:
return SchedulerSwappedInOutputs(
decode_seq_groups=[],
prefill_seq_groups=[],
blocks_to_swap_in={},
blocks_to_swap_in=[],
blocks_to_copy=[],
num_lookahead_slots=0,
infeasible_seq_groups=[],
@ -392,7 +392,7 @@ class Scheduler:
scheduling and SchedulerRunningOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_out: Dict[int, int] = {}
blocks_to_swap_out: List[Tuple[int, int]] = []
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
@ -509,7 +509,7 @@ class Scheduler:
SchedulerSwappedInOutputs.
"""
# Blocks that need to be swapped or copied before model execution.
blocks_to_swap_in: Dict[int, int] = {}
blocks_to_swap_in: List[Tuple[int, int]] = []
blocks_to_copy: List[Tuple[int, int]] = []
decode_seq_groups: List[ScheduledSequenceGroup] = []
prefill_seq_groups: List[ScheduledSequenceGroup] = []
@ -1032,7 +1032,7 @@ class Scheduler:
def _preempt(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
blocks_to_swap_out: List[Tuple[int, int]],
preemption_mode: Optional[PreemptionMode] = None,
) -> PreemptionMode:
# If preemption mode is not specified, we determine the mode as follows:
@ -1073,24 +1073,24 @@ class Scheduler:
def _preempt_by_swap(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
blocks_to_swap_out: List[Tuple[int, int]],
) -> None:
self._swap_out(seq_group, blocks_to_swap_out)
def _swap_in(
self,
seq_group: SequenceGroup,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_in: List[Tuple[int, int]],
) -> None:
mapping = self.block_manager.swap_in(seq_group)
blocks_to_swap_in.update(mapping)
blocks_to_swap_in.extend(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
seq.status = SequenceStatus.RUNNING
def _swap_out(
self,
seq_group: SequenceGroup,
blocks_to_swap_out: Dict[int, int],
blocks_to_swap_out: List[Tuple[int, int]],
) -> None:
if not self.block_manager.can_swap_out(seq_group):
# FIXME(woosuk): Abort the sequence group instead of aborting the
@ -1099,7 +1099,7 @@ class Scheduler:
"Aborted due to the lack of CPU swap space. Please increase "
"the swap space to avoid this error.")
mapping = self.block_manager.swap_out(seq_group)
blocks_to_swap_out.update(mapping)
blocks_to_swap_out.extend(mapping)
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
seq.status = SequenceStatus.SWAPPED
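
With a list instead of a dict, the scheduler accumulates pairs from multiple sequence groups via extend() rather than dict.update(). A condensed sketch (the surrounding loop is hypothetical):

blocks_to_swap_out = []                          # shared across one scheduling pass
for seq_group in preempted_by_swap:              # hypothetical collection
    mapping = block_manager.swap_out(seq_group)  # list of (gpu_block, cpu_block) pairs
    blocks_to_swap_out.extend(mapping)           # was: blocks_to_swap_out.update(mapping)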

View File

@ -741,10 +741,10 @@ class ExecuteModelRequest:
"""The model execution request."""
# The sequence group metadata list.
seq_group_metadata_list: List[SequenceGroupMetadata]
# Blocks to swap in. Dict of CPU -> GPU block number.
blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
# Blocks to swap out. Dict of GPU -> CPU block number.
blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
# Blocks to copy. Source to dest block.
blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
# The number of slots for lookahead decoding.

View File

@ -1,5 +1,5 @@
"""CacheEngine class for managing the KV cache."""
from typing import Dict, List
from typing import List
import torch
@ -67,12 +67,12 @@ class CacheEngine:
device=device))
return kv_cache
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
def swap_in(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_layers):
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
src_to_dst)
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
def swap_out(self, src_to_dst: torch.Tensor) -> None:
for i in range(self.num_layers):
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst)

View File

@ -195,15 +195,14 @@ class Worker(WorkerBase):
def cache_swap(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_swap_in: torch.Tensor,
blocks_to_swap_out: torch.Tensor,
blocks_to_copy: torch.Tensor,
) -> None:
# Issue cache operations.
# TODO(woosuk): Profile swapping overhead and optimize if needed.
if blocks_to_swap_in:
if blocks_to_swap_in.numel() > 0:
self.cache_engine.swap_in(blocks_to_swap_in)
if blocks_to_swap_out:
if blocks_to_swap_out.numel() > 0:
self.cache_engine.swap_out(blocks_to_swap_out)
if blocks_to_copy.numel() > 0:
self.cache_engine.copy(blocks_to_copy)
@ -219,12 +218,26 @@ class Worker(WorkerBase):
else:
seq_group_metadata_list = execute_model_req.seq_group_metadata_list
blocks_to_swap_in: torch.Tensor
blocks_to_swap_out: torch.Tensor
blocks_to_copy: torch.Tensor
if self.is_driver_worker:
assert seq_group_metadata_list is not None
assert execute_model_req is not None
num_seq_groups = len(seq_group_metadata_list)
blocks_to_swap_in = execute_model_req.blocks_to_swap_in
blocks_to_swap_out = execute_model_req.blocks_to_swap_out
# `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
# they contain parameters to launch cudamemcpyasync.
blocks_to_swap_in = torch.tensor(
execute_model_req.blocks_to_swap_in,
device="cpu",
dtype=torch.int64).view(-1, 2)
blocks_to_swap_out = torch.tensor(
execute_model_req.blocks_to_swap_out,
device="cpu",
dtype=torch.int64).view(-1, 2)
# `blocks_to_copy` is a gpu tensor. The src and tgt of
# blocks to copy are in the same device, and `blocks_to_copy`
# can be used directly within cuda kernels.
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
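
End to end, the driver worker now receives the swap lists through ExecuteModelRequest, converts them into CPU int64 tensors of shape [-1, 2] (they merely carry the arguments for cudaMemcpyAsync), and builds blocks_to_copy directly on the GPU where the copy kernel consumes it. A compressed sketch of that path (block numbers are illustrative; the import path is an assumption):

import torch
from vllm.sequence import ExecuteModelRequest

req = ExecuteModelRequest(
    seq_group_metadata_list=[],
    blocks_to_swap_in=[(19, 45), (67, 23)],   # CPU -> GPU block numbers (made up)
    blocks_to_swap_out=[(3, 72)],             # GPU -> CPU block numbers (made up)
    blocks_to_copy=[(2, 3)],
)
# inside the driver worker:
blocks_to_swap_in = torch.tensor(req.blocks_to_swap_in, device="cpu",
                                 dtype=torch.int64).view(-1, 2)
blocks_to_swap_out = torch.tensor(req.blocks_to_swap_out, device="cpu",
                                  dtype=torch.int64).view(-1, 2)
blocks_to_copy = torch.tensor(req.blocks_to_copy, device="cuda",
                              dtype=torch.int64).view(-1, 2)
# CacheEngine.swap_in / swap_out then pass the CPU tensors, layer by layer,
# to the attention backend's swap_blocks, which calls the C++ op shown at the top.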