[Core][Optimization] change python dict to pytorch tensor for blocks to swap (#4659)

youkaichao 2024-05-08 12:07:05 -07:00 committed by GitHub
parent ad932a221d
commit 20cfcdec99
21 changed files with 137 additions and 109 deletions

View File

@@ -8,12 +8,12 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping);
+  const torch::Tensor& block_mapping);

 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping);
+  const torch::Tensor& block_mapping);

 void reshape_and_cache(
   torch::Tensor& key,

View File

@@ -23,7 +23,7 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping) {
+  const torch::Tensor& block_mapping) {
   torch::Device src_device = src.device();
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
@@ -40,6 +40,11 @@ void swap_blocks(
     TORCH_CHECK(false, "Invalid device combination");
   }

+  // NOTE(youkaichao): keep in mind that `block_mapping` should be
+  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
+  // synchronization.
+  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
+
   char *src_ptr = static_cast<char*>(src.data_ptr());
   char *dst_ptr = static_cast<char*>(dst.data_ptr());
@@ -47,9 +52,10 @@ void swap_blocks(
   const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    int64_t dst_block_number = pair.second;
+  const int64_t num_blocks = block_mapping.size(0);
+  for (size_t i = 0; i < num_blocks; i++) {
+    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
     int64_t src_offset = src_block_number * block_size_in_bytes;
     int64_t dst_offset = dst_block_number * block_size_in_bytes;
     cudaMemcpyAsync(
@@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping) {
+  const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
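The hunks above change `swap_blocks` to take a tensor and require it to live on the CPU. A minimal sketch of the layout the kernel now expects, a [num_pairs, 2] int64 CPU tensor whose rows are (src, dst) block numbers (assumes only a working PyTorch install; the pair values are illustrative, taken from the worker test later in this diff):

import torch

# Hypothetical (src, dst) block pairs; the real pairs come from the scheduler.
pairs = [(3, 72), (56, 35), (84, 34)]

# The kernel checks block_mapping.device().is_cpu(), so the tensor stays on the
# CPU; dtype is int64 because each row is read with .item<int64_t>().
block_mapping = torch.tensor(pairs, dtype=torch.int64, device="cpu").view(-1, 2)

assert block_mapping.shape == (len(pairs), 2)
# ops.swap_blocks(src_cache, dst_cache, block_mapping)  # call site, as in the tests below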

View File

@@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl(
 void copy_blocks(std::vector<torch::Tensor> &key_caches,
                  std::vector<torch::Tensor> &value_caches,
-                 torch::Tensor& block_mapping) {
+                 const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
 }

 void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
-                 const std::map<int64_t, int64_t> &block_mapping) {
+                 const torch::Tensor& block_mapping) {
   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
 }

View File

@@ -219,7 +219,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_out(seq_group)
-    assert list(mapping.keys()) == gpu_blocks
+    assert [x[0] for x in mapping] == gpu_blocks
     after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
@@ -232,7 +232,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_in(seq_group)
-    assert list(mapping.keys()) == cpu_blocks
+    assert [x[0] for x in mapping] == cpu_blocks
     after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks

View File

@@ -355,8 +355,8 @@ def test_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []

     # Add 1 more task. Swap should be prioritized over new prefill.
     _, seq_group = create_dummy_prompt("2", prompt_length=60)
@@ -365,8 +365,8 @@ def test_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []


 def test_running_prefill_prioritized_over_swap():
@@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []

     # Add 1 more task. Swap is not possible, so prefill is running.
     scheduler.block_manager.can_swap_in = MagicMock()
@@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert out.scheduled_seq_groups[0].seq_group == seq_group2

     # Now although swap is possible, running prefill is prioritized.
@@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 1
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 1
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []


 def test_chunked_prefill_preempt():
@@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out == {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out == []
+    assert out.blocks_to_swap_in == []

     # Make sure we can reschedule preempted request.
     _, out = schedule_and_update_computed_tokens(scheduler)

View File

@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
     seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 2
     assert out.num_batched_tokens == 2
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
     append_new_token(out, 1)

     # Add 1 more task. Swap should be prioritized over prefill.
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
     assert len(out.scheduled_seq_groups) == 3
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 3
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []


 def initialize_scheduler(*,
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
     # NOTE: When enable_chunk is False, num_seqs budget is not updated.
     # assert budget.num_curr_seqs == 1
     # Both should be preempted, not swapped.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Nothing is copied.
     assert output.blocks_to_copy == []
@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
     scheduler.block_manager.can_append_slots.side_effect = (
         cannot_append_second_group)
     scheduler.block_manager.swap_out = MagicMock()
-    expected_swap_mapping = {"5": "7"}
+    expected_swap_mapping = [("5", "7")]
     scheduler.block_manager.swap_out.return_value = expected_swap_mapping

     remainig_running, output = scheduler._schedule_running(
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
     assert len(output.preempted) == 0
     assert len(output.swapped_out) == 0
     # Nothing is preempted.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Since append_slot returns the source -> dist mapping, it should
     # applied.
     assert output.blocks_to_copy == [(2, 3)]
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
     assert len(output.decode_seq_groups) == 1
     assert len(output.prefill_seq_groups) == 0
     # swap in is the reverse of swap out
-    blocks_to_swap_in_reverse = {}
-    for swapin, swapout in output.blocks_to_swap_in.items():
-        blocks_to_swap_in_reverse[swapout] = swapin
+    blocks_to_swap_in_reverse = []
+    for swapin, swapout in output.blocks_to_swap_in:
+        blocks_to_swap_in_reverse.append((swapout, swapin))
     assert blocks_to_swap_out == blocks_to_swap_in_reverse
@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(4):
         _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
         scheduler._allocate_and_set_running(seq_group)
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = set()
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -808,7 +808,7 @@ def test_infeasible_swap():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
     _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     scheduler._swap_out(seq_group, blocks_to_swap_out)
     swapped.append(seq_group)

View File

@@ -315,7 +315,10 @@ def test_swap_blocks(
     else:
         dst_blocks = random.sample(range(num_blocks), num_mappings)
-    block_mapping = dict(zip(src_blocks, dst_blocks))
+    block_mapping = list(zip(src_blocks, dst_blocks))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device="cpu").view(-1, 2)

     # Create the KV caches on the first device.
     src_key_caches, src_value_caches = kv_cache_factory(
@@ -331,10 +334,12 @@ def test_swap_blocks(
     src_value_caches_clone = src_value_caches[0].clone()

     # Call the swap_blocks kernel.
-    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
-    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
+                    block_mapping_tensor)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                    block_mapping_tensor)

-    for src, dst in block_mapping.items():
+    for src, dst in block_mapping:
         assert torch.allclose(src_key_caches_clone[src].cpu(),
                               dist_key_caches[0][dst].cpu())
         assert torch.allclose(src_value_caches_clone[src].cpu(),

View File

@@ -54,10 +54,10 @@ def test_swap() -> None:
             a.cuda(), b.cuda(), rtol=0.0, atol=0.0)

     # Test swap out.
-    blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
+    blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
     execute_model_req = ExecuteModelRequest(
         seq_group_metadata_list=[],
-        blocks_to_swap_in={},
+        blocks_to_swap_in=[],
         blocks_to_swap_out=blocks_to_swap_out,
         blocks_to_copy=[],
     )
@@ -66,24 +66,24 @@ def test_swap() -> None:
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in blocks_to_swap_out.items():
+        for src, dst in blocks_to_swap_out:
             assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
             assert allclose(gpu_value_cache[src], cpu_value_cache[dst])

     # Test swap in.
-    execute_model_req.blocks_to_swap_out = {}
-    execute_model_req.blocks_to_swap_in = {
-        19: 45,
-        67: 23,
-        12: 78,
-        40: 99,
-        1: 71
-    }
+    execute_model_req.blocks_to_swap_out = []
+    execute_model_req.blocks_to_swap_in = [
+        (19, 45),
+        (67, 23),
+        (12, 78),
+        (40, 99),
+        (1, 71),
+    ]
     worker.execute_model(execute_model_req=execute_model_req)

     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in execute_model_req.blocks_to_swap_in.items():
+        for src, dst in execute_model_req.blocks_to_swap_in:
             assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
             assert allclose(gpu_value_cache[dst], cpu_value_cache[src])

View File

@@ -39,7 +39,7 @@ class AttentionBackend(ABC):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError

View File

@@ -5,7 +5,7 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or
 flashinfer for all the attention operations.
 """
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch
 from flash_attn import flash_attn_varlen_func
@@ -45,7 +45,7 @@ class FlashAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

View File

@@ -39,7 +39,7 @@ class FlashInferBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError

View File

@@ -1,6 +1,6 @@
 """Attention layer ROCm GPUs."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch

@@ -43,7 +43,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

View File

@@ -1,7 +1,7 @@
 """ Attention layer with torch scaled_dot_product_attention
 and PagedAttention."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type

 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -41,7 +41,7 @@ class TorchSDPABackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

View File

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch

@@ -196,7 +196,7 @@ class PagedAttention:
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]

View File

@@ -473,11 +473,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
     def swap_in(self,
                 seq_group: SequenceGroup,
-                num_lookahead_slots: int = 0) -> Dict[int, int]:
+                num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"

         # CPU block -> GPU block.
+        # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
@@ -500,14 +501,16 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             cpu_block.block_number: gpu_block.block_number
             for cpu_block, gpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())

     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         blocks = self._get_physical_blocks(seq_group)
         return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         # GPU block -> CPU block.
+        # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             new_block_table: BlockTable = []
@@ -530,7 +533,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             gpu_block.block_number: cpu_block.block_number
             for gpu_block, cpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())

     def _free_block_table(self, block_table: BlockTable) -> None:
         # when using a sliding window, each seq will only use up
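The comments added in this hunk describe the pattern the block manager now follows: keep a dict while the mapping is being built so membership checks stay O(1), then convert to an ordered list of (src, dst) pairs once at the return boundary. A small standalone sketch of that shape, with made-up block numbers:

from typing import Dict, List, Tuple

def to_pair_list(mapping: Dict[int, int]) -> List[Tuple[int, int]]:
    # One conversion at the end, mirroring `list(block_number_mapping.items())`.
    return list(mapping.items())

mapping: Dict[int, int] = {}
for src, dst in [(0, 10), (1, 11)]:   # hypothetical block numbers
    if src not in mapping:            # O(1) lookup while the mapping is built
        mapping[src] = dst

assert to_pair_list(mapping) == [(0, 10), (1, 11)]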

View File

@@ -243,13 +243,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         return AllocStatus.LATER

     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         raise NotImplementedError

     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         return False

-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         raise NotImplementedError

     def get_num_free_gpu_blocks(self) -> int:

View File

@@ -1,6 +1,6 @@
 import enum
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import List
 from typing import Sequence as GenericSequence
 from typing import Tuple

@@ -69,7 +69,7 @@ class BlockSpaceManager(ABC):
     @abstractmethod
     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         pass

     @abstractmethod
@@ -77,7 +77,7 @@ class BlockSpaceManager(ABC):
         pass

     @abstractmethod
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         pass

     @abstractmethod

View File

@@ -117,10 +117,10 @@ class SchedulerOutputs:
     num_prefill_groups: int
     # Total number of batched tokens.
     num_batched_tokens: int
-    # Blocks to swap in. Dict of CPU -> GPU block number.
-    blocks_to_swap_in: Dict[int, int]
-    # Blocks to swap out. Dict of GPU -> CPU block number.
-    blocks_to_swap_out: Dict[int, int]
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]]
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]]
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]]
     # Sequence groups that are going to be ignored.
@@ -174,7 +174,7 @@ class SchedulerRunningOutputs:
     # Sequences that are swapped out.
     swapped_out: List[SequenceGroup]
     # The blocks to swap out.
-    blocks_to_swap_out: Dict[int, int]
+    blocks_to_swap_out: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -187,7 +187,7 @@ class SchedulerRunningOutputs:
             prefill_seq_groups=[],
             preempted=[],
             swapped_out=[],
-            blocks_to_swap_out={},
+            blocks_to_swap_out=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
         )
@@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs:
     # phase. I.e., it means the prefill has been chunked.
     prefill_seq_groups: List[SequenceGroup]
     # The blocks to swap in.
-    blocks_to_swap_in: Dict[int, int]
+    blocks_to_swap_in: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -219,7 +219,7 @@ class SchedulerSwappedInOutputs:
         return SchedulerSwappedInOutputs(
             decode_seq_groups=[],
             prefill_seq_groups=[],
-            blocks_to_swap_in={},
+            blocks_to_swap_in=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
             infeasible_seq_groups=[],
@@ -392,7 +392,7 @@ class Scheduler:
             scheduling and SchedulerRunningOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_out: Dict[int, int] = {}
+        blocks_to_swap_out: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []

         decode_seq_groups: List[ScheduledSequenceGroup] = []
@@ -509,7 +509,7 @@ class Scheduler:
             SchedulerSwappedInOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_in: Dict[int, int] = {}
+        blocks_to_swap_in: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []

         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
@@ -1032,7 +1032,7 @@ class Scheduler:
     def _preempt(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
         preemption_mode: Optional[PreemptionMode] = None,
     ) -> PreemptionMode:
         # If preemption mode is not specified, we determine the mode as follows:
@@ -1073,24 +1073,24 @@ class Scheduler:
     def _preempt_by_swap(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         self._swap_out(seq_group, blocks_to_swap_out)

     def _swap_in(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_in: List[Tuple[int, int]],
     ) -> None:
         mapping = self.block_manager.swap_in(seq_group)
-        blocks_to_swap_in.update(mapping)
+        blocks_to_swap_in.extend(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             seq.status = SequenceStatus.RUNNING

     def _swap_out(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         if not self.block_manager.can_swap_out(seq_group):
             # FIXME(woosuk): Abort the sequence group instead of aborting the
@@ -1099,7 +1099,7 @@ class Scheduler:
                 "Aborted due to the lack of CPU swap space. Please increase "
                 "the swap space to avoid this error.")
         mapping = self.block_manager.swap_out(seq_group)
-        blocks_to_swap_out.update(mapping)
+        blocks_to_swap_out.extend(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
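With the list representation, the scheduler accumulates swap pairs with `extend` instead of merging dicts with `update`, so repeated preemptions simply append pairs in order. A minimal sketch of that accumulation (pair values are illustrative, not taken from a real run):

from typing import List, Tuple

blocks_to_swap_out: List[Tuple[int, int]] = []

def record_swap_out(mapping: List[Tuple[int, int]]) -> None:
    # Mirrors `blocks_to_swap_out.extend(mapping)` in `_swap_out`.
    blocks_to_swap_out.extend(mapping)

record_swap_out([(5, 7)])   # hypothetical GPU -> CPU pair
record_swap_out([(6, 8)])
assert blocks_to_swap_out == [(5, 7), (6, 8)]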

View File

@@ -741,10 +741,10 @@ class ExecuteModelRequest:
     """The model execution request."""
     # The sequence group metadata list.
     seq_group_metadata_list: List[SequenceGroupMetadata]
-    # Blocks to swap in. Dict of CPU -> GPU block number.
-    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
-    # Blocks to swap out. Dict of GPU -> CPU block number.
-    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
     # The number of slots for lookahead decoding.

View File

@@ -1,5 +1,5 @@
 """CacheEngine class for managing the KV cache."""
-from typing import Dict, List
+from typing import List

 import torch

@@ -67,12 +67,12 @@ class CacheEngine:
                          device=device))
         return kv_cache

-    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_in(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_layers):
             self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                           src_to_dst)

-    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_out(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_layers):
             self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                           src_to_dst)

View File

@@ -195,15 +195,14 @@ class Worker(WorkerBase):
     def cache_swap(
         self,
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_in: torch.Tensor,
+        blocks_to_swap_out: torch.Tensor,
         blocks_to_copy: torch.Tensor,
     ) -> None:
         # Issue cache operations.
-        # TODO(woosuk): Profile swapping overhead and optimize if needed.
-        if blocks_to_swap_in:
+        if blocks_to_swap_in.numel() > 0:
             self.cache_engine.swap_in(blocks_to_swap_in)
-        if blocks_to_swap_out:
+        if blocks_to_swap_out.numel() > 0:
             self.cache_engine.swap_out(blocks_to_swap_out)
         if blocks_to_copy.numel() > 0:
             self.cache_engine.copy(blocks_to_copy)
@@ -219,12 +218,26 @@ class Worker(WorkerBase):
         else:
             seq_group_metadata_list = execute_model_req.seq_group_metadata_list

+        blocks_to_swap_in: torch.Tensor
+        blocks_to_swap_out: torch.Tensor
+        blocks_to_copy: torch.Tensor
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
             assert execute_model_req is not None
             num_seq_groups = len(seq_group_metadata_list)
-            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
-            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
+            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
+            # they contain parameters to launch cudamemcpyasync.
+            blocks_to_swap_in = torch.tensor(
+                execute_model_req.blocks_to_swap_in,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            blocks_to_swap_out = torch.tensor(
+                execute_model_req.blocks_to_swap_out,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            # `blocks_to_copy` is a gpu tensor. The src and tgt of
+            # blocks to copy are in the same device, and `blocks_to_copy`
+            # can be used directly within cuda kernels.
             blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                           device=self.device,
                                           dtype=torch.int64).view(-1, 2)
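Taken together, the driver worker now converts the scheduler's plain Python lists into tensors once per step: the swap mappings become CPU int64 tensors (they only parameterize cudaMemcpyAsync launches on the host), while the copy mapping is placed on the worker's device because the copy kernel consumes it directly. A self-contained sketch of that conversion and the emptiness checks (block numbers are illustrative; it falls back to CPU when no GPU is present):

import torch

blocks_to_swap_in = [(19, 45), (67, 23)]   # CPU -> GPU pairs from the request
blocks_to_copy = [(2, 3)]                  # intra-GPU pairs from the request

# CPU tensor: rows are read one by one on the host to launch the memcpys.
swap_in_t = torch.tensor(blocks_to_swap_in, device="cpu",
                         dtype=torch.int64).view(-1, 2)
# Device tensor: consumed directly inside the copy kernel.
device = "cuda" if torch.cuda.is_available() else "cpu"
copy_t = torch.tensor(blocks_to_copy, device=device,
                      dtype=torch.int64).view(-1, 2)

# A multi-element tensor has no unambiguous truth value, so numel() checks
# replace the old `if blocks_to_swap_in:` test on a dict.
if swap_in_t.numel() > 0:
    print("would call cache_engine.swap_in with shape", tuple(swap_in_t.shape))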