[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)
Parent: ad932a221d
Commit: 20cfcdec99
@@ -8,12 +8,12 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping);
+  const torch::Tensor& block_mapping);

 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping);
+  const torch::Tensor& block_mapping);

 void reshape_and_cache(
   torch::Tensor& key,
@@ -23,7 +23,7 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping) {
+  const torch::Tensor& block_mapping) {
   torch::Device src_device = src.device();
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
@@ -40,6 +40,11 @@ void swap_blocks(
     TORCH_CHECK(false, "Invalid device combination");
   }

+  // NOTE(youkaichao): keep in mind that `block_mapping` should be
+  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
+  // synchronization.
+  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
+
   char *src_ptr = static_cast<char*>(src.data_ptr());
   char *dst_ptr = static_cast<char*>(dst.data_ptr());

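The added check encodes the key constraint of the new interface: the launcher below reads every mapping entry with `item()`, which is a plain host-memory read for a CPU tensor but a device synchronization for a CUDA tensor. A minimal sketch of building a conforming mapping on the Python side (pair values borrowed from the worker test later in this diff):

import torch

# (src_block, dst_block) pairs; values borrowed from the worker test below.
pairs = [(3, 72), (56, 35), (84, 34)]

# Keep the mapping on the CPU: swap_blocks() reads every element with
# .item(), which would otherwise trigger a GPU->CPU sync per element.
block_mapping = torch.tensor(pairs, dtype=torch.int64, device="cpu").view(-1, 2)

assert block_mapping.device.type == "cpu"
assert block_mapping.shape == (len(pairs), 2)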
@@ -47,9 +52,10 @@ void swap_blocks(
   const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    int64_t dst_block_number = pair.second;
+  const int64_t num_blocks = block_mapping.size(0);
+  for (size_t i = 0; i < num_blocks; i++) {
+    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
     int64_t src_offset = src_block_number * block_size_in_bytes;
     int64_t dst_offset = dst_block_number * block_size_in_bytes;
     cudaMemcpyAsync(
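Each row of the [num_blocks, 2] tensor is one (src_block, dst_block) pair, and the loop issues one asynchronous copy per row. The same logic, as a pure-PyTorch sketch for reference only (assuming dimension 0 of each cache indexes blocks, as in the byte-offset computation above):

import torch

def swap_blocks_reference(src: torch.Tensor, dst: torch.Tensor,
                          block_mapping: torch.Tensor) -> None:
    """Pure-Python sketch of the copy loop; not the CUDA implementation."""
    assert block_mapping.device.type == "cpu"
    for i in range(block_mapping.size(0)):
        src_block = int(block_mapping[i, 0])  # cheap scalar read on a CPU tensor
        dst_block = int(block_mapping[i, 1])
        # One block copy per mapping row, queued asynchronously like the
        # cudaMemcpyAsync calls in the kernel launcher.
        dst[dst_block].copy_(src[src_block], non_blocking=True)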
@@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping) {
+  const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl(

 void copy_blocks(std::vector<torch::Tensor> &key_caches,
                  std::vector<torch::Tensor> &value_caches,
-                 torch::Tensor& block_mapping) {
+                 const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
 }

 void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
-                 const std::map<int64_t, int64_t> &block_mapping) {
+                 const torch::Tensor& block_mapping) {
   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
 }
@@ -219,7 +219,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_out(seq_group)
-    assert list(mapping.keys()) == gpu_blocks
+    assert [x[0] for x in mapping] == gpu_blocks
     after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
@@ -232,7 +232,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_in(seq_group)
-    assert list(mapping.keys()) == cpu_blocks
+    assert [x[0] for x in mapping] == cpu_blocks
     after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
@@ -355,8 +355,8 @@ def test_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []

     # Add 1 more task. Swap should be prioritized over new prefill.
     _, seq_group = create_dummy_prompt("2", prompt_length=60)
@@ -365,8 +365,8 @@ def test_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []


 def test_running_prefill_prioritized_over_swap():
@@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []

     # Add 1 more task. Swap is not possible, so prefill is running.
     scheduler.block_manager.can_swap_in = MagicMock()
@@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert out.scheduled_seq_groups[0].seq_group == seq_group2

     # Now although swap is possible, running prefill is prioritized.
@@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 1
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 1
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []


 def test_chunked_prefill_preempt():
@@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out == {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out == []
+    assert out.blocks_to_swap_in == []

     # Make sure we can reschedule preempted request.
     _, out = schedule_and_update_computed_tokens(scheduler)
@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
     seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 2
     assert out.num_batched_tokens == 2
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
     append_new_token(out, 1)

     # Add 1 more task. Swap should be prioritized over prefill.
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
     assert len(out.scheduled_seq_groups) == 3
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 3
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []


 def initialize_scheduler(*,
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
     # NOTE: When enable_chunk is False, num_seqs budget is not updated.
     # assert budget.num_curr_seqs == 1
     # Both should be preempted, not swapped.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Nothing is copied.
     assert output.blocks_to_copy == []

@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
     scheduler.block_manager.can_append_slots.side_effect = (
         cannot_append_second_group)
     scheduler.block_manager.swap_out = MagicMock()
-    expected_swap_mapping = {"5": "7"}
+    expected_swap_mapping = [("5", "7")]
     scheduler.block_manager.swap_out.return_value = expected_swap_mapping

     remainig_running, output = scheduler._schedule_running(
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
     assert len(output.preempted) == 0
     assert len(output.swapped_out) == 0
     # Nothing is preempted.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Since append_slot returns the source -> dist mapping, it should
     # applied.
     assert output.blocks_to_copy == [(2, 3)]
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
     assert len(output.decode_seq_groups) == 1
     assert len(output.prefill_seq_groups) == 0
     # swap in is the reverse of swap out
-    blocks_to_swap_in_reverse = {}
-    for swapin, swapout in output.blocks_to_swap_in.items():
-        blocks_to_swap_in_reverse[swapout] = swapin
+    blocks_to_swap_in_reverse = []
+    for swapin, swapout in output.blocks_to_swap_in:
+        blocks_to_swap_in_reverse.append((swapout, swapin))
     assert blocks_to_swap_out == blocks_to_swap_in_reverse


@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(4):
         _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
         scheduler._allocate_and_set_running(seq_group)
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = set()
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -808,7 +808,7 @@ def test_infeasible_swap():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
     _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     scheduler._swap_out(seq_group, blocks_to_swap_out)
     swapped.append(seq_group)

@@ -315,7 +315,10 @@ def test_swap_blocks(
     else:
         dst_blocks = random.sample(range(num_blocks), num_mappings)

-    block_mapping = dict(zip(src_blocks, dst_blocks))
+    block_mapping = list(zip(src_blocks, dst_blocks))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device="cpu").view(-1, 2)

     # Create the KV caches on the first device.
     src_key_caches, src_value_caches = kv_cache_factory(
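The test now keeps two views of the same mapping: the Python list of (src, dst) pairs for the final assertions, and the packed int64 CPU tensor for the kernel call. The `.view(-1, 2)` also keeps the empty-mapping case well formed, as this small illustrative sketch shows (the helper name is made up):

import torch

def to_mapping_tensor(pairs):
    """Pack [(src, dst), ...] into the [N, 2] int64 CPU layout the kernel expects."""
    return torch.tensor(pairs, dtype=torch.int64, device="cpu").view(-1, 2)

print(to_mapping_tensor([(3, 72), (56, 35)]).shape)  # torch.Size([2, 2])
print(to_mapping_tensor([]).shape)                   # torch.Size([0, 2]); numel() == 0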
@@ -331,10 +334,12 @@ def test_swap_blocks(
     src_value_caches_clone = src_value_caches[0].clone()

     # Call the swap_blocks kernel.
-    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
-    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
+                    block_mapping_tensor)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                    block_mapping_tensor)

-    for src, dst in block_mapping.items():
+    for src, dst in block_mapping:
         assert torch.allclose(src_key_caches_clone[src].cpu(),
                               dist_key_caches[0][dst].cpu())
         assert torch.allclose(src_value_caches_clone[src].cpu(),
|
@ -54,10 +54,10 @@ def test_swap() -> None:
|
||||
a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
|
||||
|
||||
# Test swap out.
|
||||
blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
|
||||
blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
|
||||
execute_model_req = ExecuteModelRequest(
|
||||
seq_group_metadata_list=[],
|
||||
blocks_to_swap_in={},
|
||||
blocks_to_swap_in=[],
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=[],
|
||||
)
|
||||
@ -66,24 +66,24 @@ def test_swap() -> None:
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
cpu_key_cache, cpu_value_cache = cpu_cache[i]
|
||||
for src, dst in blocks_to_swap_out.items():
|
||||
for src, dst in blocks_to_swap_out:
|
||||
assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
|
||||
assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
|
||||
|
||||
# Test swap in.
|
||||
execute_model_req.blocks_to_swap_out = {}
|
||||
execute_model_req.blocks_to_swap_in = {
|
||||
19: 45,
|
||||
67: 23,
|
||||
12: 78,
|
||||
40: 99,
|
||||
1: 71
|
||||
}
|
||||
execute_model_req.blocks_to_swap_out = []
|
||||
execute_model_req.blocks_to_swap_in = [
|
||||
(19, 45),
|
||||
(67, 23),
|
||||
(12, 78),
|
||||
(40, 99),
|
||||
(1, 71),
|
||||
]
|
||||
worker.execute_model(execute_model_req=execute_model_req)
|
||||
|
||||
for i in range(num_layers):
|
||||
gpu_key_cache, gpu_value_cache = gpu_cache[i]
|
||||
cpu_key_cache, cpu_value_cache = cpu_cache[i]
|
||||
for src, dst in execute_model_req.blocks_to_swap_in.items():
|
||||
for src, dst in execute_model_req.blocks_to_swap_in:
|
||||
assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
|
||||
assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
|
||||
|
@@ -39,7 +39,7 @@ class AttentionBackend(ABC):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError

@@ -5,7 +5,7 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or
 flashinfer for all the attention operations.
 """
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch
 from flash_attn import flash_attn_varlen_func
@@ -45,7 +45,7 @@ class FlashAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

@@ -39,7 +39,7 @@ class FlashInferBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError

@@ -1,6 +1,6 @@
 """Attention layer ROCm GPUs."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch

@@ -43,7 +43,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

@@ -1,7 +1,7 @@
 """ Attention layer with torch scaled_dot_product_attention
 and PagedAttention."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -41,7 +41,7 @@ class TorchSDPABackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

@@ -196,7 +196,7 @@ class PagedAttention:
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
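All GPU backends forward `swap_blocks` to `PagedAttention.swap_blocks`, which, per the context lines above, splits the stacked KV cache and hands the same mapping tensor to the custom op for both halves. A hedged sketch of that forwarding, with the `ops` binding and its import path assumed from the kernel tests earlier in this diff:

import torch
from vllm import _custom_ops as ops  # import path assumed; the tests above call it `ops`

def swap_blocks(src_kv_cache: torch.Tensor,
                dst_kv_cache: torch.Tensor,
                src_to_dst: torch.Tensor) -> None:
    # Sketch: index 0 holds the key cache, index 1 the value cache.
    src_key_cache, src_value_cache = src_kv_cache[0], src_kv_cache[1]
    dst_key_cache, dst_value_cache = dst_kv_cache[0], dst_kv_cache[1]
    # The same [N, 2] CPU mapping tensor is reused for both copies.
    ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
    ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)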
@@ -473,11 +473,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):

     def swap_in(self,
                 seq_group: SequenceGroup,
-                num_lookahead_slots: int = 0) -> Dict[int, int]:
+                num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"

         # CPU block -> GPU block.
+        # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
@@ -500,14 +501,16 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             cpu_block.block_number: gpu_block.block_number
             for cpu_block, gpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())

     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         blocks = self._get_physical_blocks(seq_group)
         return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         # GPU block -> CPU block.
+        # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             new_block_table: BlockTable = []
@@ -530,7 +533,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             gpu_block.block_number: cpu_block.block_number
             for gpu_block, cpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())

     def _free_block_table(self, block_table: BlockTable) -> None:
         # when using a sliding window, each seq will only use up
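The added comments spell out the trade-off: a dict is kept while building the mapping because shared physical blocks need an O(1) `in` check, and the conversion to the list-of-pairs format happens exactly once at the return. A minimal illustrative sketch of that pattern (the allocator callback is hypothetical):

from typing import Dict, List, Tuple

def swap_out_sketch(block_tables: List[List[int]],
                    allocate_cpu_block) -> List[Tuple[int, int]]:
    # Build with a dict: blocks shared between sequences are swapped only once,
    # and `gpu_block in mapping` is an O(1) membership check.
    mapping: Dict[int, int] = {}
    for table in block_tables:
        for gpu_block in table:
            if gpu_block not in mapping:
                mapping[gpu_block] = allocate_cpu_block()
    # Convert to the list-of-(src, dst) format once, at the boundary.
    return list(mapping.items())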
@@ -243,13 +243,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         return AllocStatus.LATER

     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         raise NotImplementedError

     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         return False

-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         raise NotImplementedError

     def get_num_free_gpu_blocks(self) -> int:
@@ -1,6 +1,6 @@
 import enum
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import List
 from typing import Sequence as GenericSequence
 from typing import Tuple

@@ -69,7 +69,7 @@ class BlockSpaceManager(ABC):

     @abstractmethod
     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         pass

     @abstractmethod
@@ -77,7 +77,7 @@ class BlockSpaceManager(ABC):
         pass

     @abstractmethod
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         pass

     @abstractmethod
@@ -117,10 +117,10 @@ class SchedulerOutputs:
     num_prefill_groups: int
     # Total number of batched tokens.
     num_batched_tokens: int
-    # Blocks to swap in. Dict of CPU -> GPU block number.
-    blocks_to_swap_in: Dict[int, int]
-    # Blocks to swap out. Dict of GPU -> CPU block number.
-    blocks_to_swap_out: Dict[int, int]
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]]
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]]
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]]
     # Sequence groups that are going to be ignored.
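With this change the two swap mappings use the same shape as `blocks_to_copy`: an ordered list of (src, dst) pairs that converts straight into the [N, 2] tensor the worker builds. A small illustration, with block numbers borrowed from the tests above:

import torch

blocks_to_swap_in = [(19, 45), (67, 23)]   # (CPU block, GPU block)
blocks_to_swap_out = [(3, 72)]             # (GPU block, CPU block)
blocks_to_copy = [(2, 3)]                  # (src, dst), same device

# Each list packs directly into the tensor layout the cache kernels consume.
swap_in_tensor = torch.tensor(blocks_to_swap_in, dtype=torch.int64,
                              device="cpu").view(-1, 2)
assert swap_in_tensor.shape == (2, 2)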
@@ -174,7 +174,7 @@ class SchedulerRunningOutputs:
     # Sequences that are swapped out.
     swapped_out: List[SequenceGroup]
     # The blocks to swap out.
-    blocks_to_swap_out: Dict[int, int]
+    blocks_to_swap_out: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -187,7 +187,7 @@ class SchedulerRunningOutputs:
             prefill_seq_groups=[],
             preempted=[],
             swapped_out=[],
-            blocks_to_swap_out={},
+            blocks_to_swap_out=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
         )
@@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs:
     # phase. I.e., it means the prefill has been chunked.
     prefill_seq_groups: List[SequenceGroup]
     # The blocks to swap in.
-    blocks_to_swap_in: Dict[int, int]
+    blocks_to_swap_in: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -219,7 +219,7 @@ class SchedulerSwappedInOutputs:
         return SchedulerSwappedInOutputs(
             decode_seq_groups=[],
             prefill_seq_groups=[],
-            blocks_to_swap_in={},
+            blocks_to_swap_in=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
             infeasible_seq_groups=[],
@@ -392,7 +392,7 @@ class Scheduler:
         scheduling and SchedulerRunningOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_out: Dict[int, int] = {}
+        blocks_to_swap_out: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []

         decode_seq_groups: List[ScheduledSequenceGroup] = []
@@ -509,7 +509,7 @@ class Scheduler:
         SchedulerSwappedInOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_in: Dict[int, int] = {}
+        blocks_to_swap_in: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
@@ -1032,7 +1032,7 @@ class Scheduler:
     def _preempt(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
         preemption_mode: Optional[PreemptionMode] = None,
     ) -> PreemptionMode:
         # If preemption mode is not specified, we determine the mode as follows:
@@ -1073,24 +1073,24 @@ class Scheduler:
     def _preempt_by_swap(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         self._swap_out(seq_group, blocks_to_swap_out)

     def _swap_in(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_in: List[Tuple[int, int]],
     ) -> None:
         mapping = self.block_manager.swap_in(seq_group)
-        blocks_to_swap_in.update(mapping)
+        blocks_to_swap_in.extend(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             seq.status = SequenceStatus.RUNNING

     def _swap_out(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         if not self.block_manager.can_swap_out(seq_group):
             # FIXME(woosuk): Abort the sequence group instead of aborting the
@@ -1099,7 +1099,7 @@ class Scheduler:
                 "Aborted due to the lack of CPU swap space. Please increase "
                 "the swap space to avoid this error.")
         mapping = self.block_manager.swap_out(seq_group)
-        blocks_to_swap_out.update(mapping)
+        blocks_to_swap_out.extend(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED

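Because the accumulators are now lists, merging each group's mapping into the step-wide mapping switches from `dict.update` to `list.extend`; pairs are appended in scheduling order instead of being keyed by block number. An illustrative sketch, with values from the worker test above:

from typing import List, Tuple

blocks_to_swap_out: List[Tuple[int, int]] = []
# Per-sequence-group results returned by block_manager.swap_out(); illustrative values.
per_group_mappings = [[(3, 72)], [(56, 35), (84, 34)]]
for mapping in per_group_mappings:
    blocks_to_swap_out.extend(mapping)

assert blocks_to_swap_out == [(3, 72), (56, 35), (84, 34)]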
@@ -741,10 +741,10 @@ class ExecuteModelRequest:
     """The model execution request."""
     # The sequence group metadata list.
     seq_group_metadata_list: List[SequenceGroupMetadata]
-    # Blocks to swap in. Dict of CPU -> GPU block number.
-    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
-    # Blocks to swap out. Dict of GPU -> CPU block number.
-    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
     # The number of slots for lookahead decoding.
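Callers therefore construct requests with lists of pairs rather than dicts; the worker test earlier in this diff is the reference usage. A minimal sketch mirroring it (import path assumed):

from vllm.sequence import ExecuteModelRequest  # import path assumed

execute_model_req = ExecuteModelRequest(
    seq_group_metadata_list=[],
    blocks_to_swap_in=[],                              # (CPU block, GPU block) pairs
    blocks_to_swap_out=[(3, 72), (56, 35), (84, 34)],  # (GPU block, CPU block) pairs
    blocks_to_copy=[],
)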
@@ -1,5 +1,5 @@
 """CacheEngine class for managing the KV cache."""
-from typing import Dict, List
+from typing import List

 import torch

@@ -67,12 +67,12 @@ class CacheEngine:
                 device=device))
         return kv_cache

-    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_in(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_layers):
             self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                           src_to_dst)

-    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_out(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_layers):
             self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                           src_to_dst)
@@ -195,15 +195,14 @@ class Worker(WorkerBase):

     def cache_swap(
         self,
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_in: torch.Tensor,
+        blocks_to_swap_out: torch.Tensor,
         blocks_to_copy: torch.Tensor,
     ) -> None:
         # Issue cache operations.
         # TODO(woosuk): Profile swapping overhead and optimize if needed.
-        if blocks_to_swap_in:
+        if blocks_to_swap_in.numel() > 0:
             self.cache_engine.swap_in(blocks_to_swap_in)
-        if blocks_to_swap_out:
+        if blocks_to_swap_out.numel() > 0:
             self.cache_engine.swap_out(blocks_to_swap_out)
         if blocks_to_copy.numel() > 0:
             self.cache_engine.copy(blocks_to_copy)
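Emptiness is now checked with `numel()` because the arguments are tensors: Python truthiness on a tensor with more than one element raises an error, so `if blocks_to_swap_in:` no longer works. A small illustration:

import torch

mapping = torch.tensor([(3, 72), (56, 35)], dtype=torch.int64, device="cpu").view(-1, 2)

# bool(mapping) would raise "Boolean value of Tensor with more than one element
# is ambiguous", so truthiness cannot stand in for an emptiness check here.
if mapping.numel() > 0:
    print("have", mapping.size(0), "blocks to swap")

empty = torch.empty((0, 2), dtype=torch.int64)
assert empty.numel() == 0  # the swap step is simply skipped for an empty mapping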
@@ -219,12 +218,26 @@
         else:
             seq_group_metadata_list = execute_model_req.seq_group_metadata_list

+        blocks_to_swap_in: torch.Tensor
+        blocks_to_swap_out: torch.Tensor
+        blocks_to_copy: torch.Tensor
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
             assert execute_model_req is not None
             num_seq_groups = len(seq_group_metadata_list)
-            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
-            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
+            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
+            # they contain parameters to launch cudamemcpyasync.
+            blocks_to_swap_in = torch.tensor(
+                execute_model_req.blocks_to_swap_in,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            blocks_to_swap_out = torch.tensor(
+                execute_model_req.blocks_to_swap_out,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            # `blocks_to_copy` is a gpu tensor. The src and tgt of
+            # blocks to copy are in the same device, and `blocks_to_copy`
+            # can be used directly within cuda kernels.
+            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                          device=self.device,
+                                          dtype=torch.int64).view(-1, 2)
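The driver worker packs the Python lists into tensors once per step: the swap mappings stay on the CPU, where the memcpy launcher reads them element by element, while the copy mapping lives on the GPU because the copy kernel consumes it directly. A short illustrative round trip of that packing, tying back to the `numel()` guards above (device strings assumed):

import torch

swap_out_pairs = [(3, 72), (56, 35), (84, 34)]
blocks_to_swap_out = torch.tensor(swap_out_pairs, device="cpu",
                                  dtype=torch.int64).view(-1, 2)
assert blocks_to_swap_out.tolist() == [list(p) for p in swap_out_pairs]

# An empty request still packs into a well-formed (0, 2) tensor, so the
# numel() > 0 guards in cache_swap() simply skip the copies.
blocks_to_swap_in = torch.tensor([], device="cpu", dtype=torch.int64).view(-1, 2)
assert blocks_to_swap_in.numel() == 0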