diff --git a/csrc/cache.h b/csrc/cache.h
index 10871b36..212a3bf3 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -8,12 +8,12 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping);
+  const torch::Tensor& block_mapping);
 
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping);
+  const torch::Tensor& block_mapping);
 
 void reshape_and_cache(
   torch::Tensor& key,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 1e02f7fc..76db96f0 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -23,7 +23,7 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping) {
+  const torch::Tensor& block_mapping) {
   torch::Device src_device = src.device();
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
@@ -40,6 +40,11 @@ void swap_blocks(
     TORCH_CHECK(false, "Invalid device combination");
   }
 
+  // NOTE(youkaichao): keep in mind that `block_mapping` should be
+  // a CPU tensor, otherwise every `item` call will require a GPU-CPU
+  // synchronization.
+  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
+
   char *src_ptr = static_cast<char*>(src.data_ptr());
   char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
@@ -47,9 +52,10 @@ void swap_blocks(
   const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    int64_t dst_block_number = pair.second;
+  const int64_t num_blocks = block_mapping.size(0);
+  for (size_t i = 0; i < num_blocks; i++) {
+    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
     int64_t src_offset = src_block_number * block_size_in_bytes;
     int64_t dst_offset = dst_block_number * block_size_in_bytes;
     cudaMemcpyAsync(
@@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping) {
+  const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp
index 95e3f119..620d11ef 100644
--- a/csrc/cpu/cache.cpp
+++ b/csrc/cpu/cache.cpp
@@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl(
 
 void copy_blocks(std::vector<torch::Tensor> &key_caches,
                  std::vector<torch::Tensor> &value_caches,
-                 torch::Tensor& block_mapping) {
+                 const torch::Tensor &block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
 }
 
 void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
-                 const std::map<int64_t, int64_t> &block_mapping) {
+                 const torch::Tensor &block_mapping) {
   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
 }
diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py
index 08d34efb..9db58e07 100644
--- a/tests/core/test_block_manager.py
+++ b/tests/core/test_block_manager.py
@@ -219,7 +219,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_out(seq_group)
-    assert list(mapping.keys()) == gpu_blocks
+    assert [x[0] for x in mapping] == gpu_blocks
     after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
@@ -232,7 +232,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_in(seq_group)
-    assert list(mapping.keys()) == cpu_blocks
+    assert [x[0] for x in mapping] == cpu_blocks
     after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py
index 92498c00..3649e6b0 100644
--- a/tests/core/test_chunked_prefill_scheduler.py
+++ b/tests/core/test_chunked_prefill_scheduler.py
@@ -355,8 +355,8 @@ def test_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
 
     # Add 1 more task. Swap should be prioritized over new prefill.
     _, seq_group = create_dummy_prompt("2", prompt_length=60)
@@ -365,8 +365,8 @@ def test_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
 
 
 def test_running_prefill_prioritized_over_swap():
@@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
 
     # Add 1 more task. Swap is not possible, so prefill is running.
     scheduler.block_manager.can_swap_in = MagicMock()
@@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
 
     # Now although swap is possible, running prefill is prioritized.
@@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 1
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 1
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
 
 
 def test_chunked_prefill_preempt():
@@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out == {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out == []
+    assert out.blocks_to_swap_in == []
 
     # Make sure we can reschedule preempted request.
     _, out = schedule_and_update_computed_tokens(scheduler)
diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py
index 3f0c918a..6bcabc4f 100644
--- a/tests/core/test_scheduler.py
+++ b/tests/core/test_scheduler.py
@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
     seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 2
     assert out.num_batched_tokens == 2
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
     append_new_token(out, 1)
 
     # Add 1 more task. Swap should be prioritized over prefill.
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
     assert len(out.scheduled_seq_groups) == 3
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 3
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
 
 
 def initialize_scheduler(*,
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
     # NOTE: When enable_chunk is False, num_seqs budget is not updated.
     # assert budget.num_curr_seqs == 1
     # Both should be preempted, not swapped.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Nothing is copied.
     assert output.blocks_to_copy == []
 
@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
     scheduler.block_manager.can_append_slots.side_effect = (
         cannot_append_second_group)
     scheduler.block_manager.swap_out = MagicMock()
-    expected_swap_mapping = {"5": "7"}
+    expected_swap_mapping = [("5", "7")]
     scheduler.block_manager.swap_out.return_value = expected_swap_mapping
 
     remainig_running, output = scheduler._schedule_running(
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
     assert len(output.preempted) == 0
     assert len(output.swapped_out) == 0
     # Nothing is preempted.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Since append_slot returns the source -> dist mapping, it should be
     # applied.
     assert output.blocks_to_copy == [(2, 3)]
 
 
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
    _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
     assert len(output.decode_seq_groups) == 1
     assert len(output.prefill_seq_groups) == 0
     # swap in is the reverse of swap out
-    blocks_to_swap_in_reverse = {}
-    for swapin, swapout in output.blocks_to_swap_in.items():
-        blocks_to_swap_in_reverse[swapout] = swapin
+    blocks_to_swap_in_reverse = []
+    for swapin, swapout in output.blocks_to_swap_in:
+        blocks_to_swap_in_reverse.append((swapout, swapin))
     assert blocks_to_swap_out == blocks_to_swap_in_reverse
 
 
@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(4):
         _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
         scheduler._allocate_and_set_running(seq_group)
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = set()
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -808,7 +808,7 @@ def test_infeasible_swap():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
     _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     scheduler._swap_out(seq_group, blocks_to_swap_out)
     swapped.append(seq_group)
 
diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py
index 94a57713..8a27d51b 100644
--- a/tests/kernels/test_cache.py
+++ b/tests/kernels/test_cache.py
@@ -315,7 +315,10 @@ def test_swap_blocks(
     else:
         dst_blocks = random.sample(range(num_blocks), num_mappings)
 
-    block_mapping = dict(zip(src_blocks, dst_blocks))
+    block_mapping = list(zip(src_blocks, dst_blocks))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device="cpu").view(-1, 2)
 
     # Create the KV caches on the first device.
     src_key_caches, src_value_caches = kv_cache_factory(
@@ -331,10 +334,12 @@ def test_swap_blocks(
     src_value_caches_clone = src_value_caches[0].clone()
 
     # Call the swap_blocks kernel.
-    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
-    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
+                    block_mapping_tensor)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                    block_mapping_tensor)
 
-    for src, dst in block_mapping.items():
+    for src, dst in block_mapping:
         assert torch.allclose(src_key_caches_clone[src].cpu(),
                               dist_key_caches[0][dst].cpu())
         assert torch.allclose(src_value_caches_clone[src].cpu(),
diff --git a/tests/worker/test_swap.py b/tests/worker/test_swap.py
index 4d2d3add..d941ffdb 100644
--- a/tests/worker/test_swap.py
+++ b/tests/worker/test_swap.py
@@ -54,10 +54,10 @@ def test_swap() -> None:
                             a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
 
     # Test swap out.
-    blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
+    blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
     execute_model_req = ExecuteModelRequest(
         seq_group_metadata_list=[],
-        blocks_to_swap_in={},
+        blocks_to_swap_in=[],
         blocks_to_swap_out=blocks_to_swap_out,
         blocks_to_copy=[],
     )
@@ -66,24 +66,24 @@ def test_swap() -> None:
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in blocks_to_swap_out.items():
+        for src, dst in blocks_to_swap_out:
            assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
            assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
 
    # Test swap in.
-    execute_model_req.blocks_to_swap_out = {}
-    execute_model_req.blocks_to_swap_in = {
-        19: 45,
-        67: 23,
-        12: 78,
-        40: 99,
-        1: 71
-    }
+    execute_model_req.blocks_to_swap_out = []
+    execute_model_req.blocks_to_swap_in = [
+        (19, 45),
+        (67, 23),
+        (12, 78),
+        (40, 99),
+        (1, 71),
+    ]
     worker.execute_model(execute_model_req=execute_model_req)
 
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in execute_model_req.blocks_to_swap_in.items():
+        for src, dst in execute_model_req.blocks_to_swap_in:
             assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
             assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 02a2fd60..64ccb309 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -39,7 +39,7 @@ class AttentionBackend(ABC):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index bee482c3..c2fec915 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -5,7 +5,7 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or
 flashinfer for all the attention operations.
""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch from flash_attn import flash_attn_varlen_func @@ -45,7 +45,7 @@ class FlashAttentionBackend(AttentionBackend): def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 67b99ba2..8f13f352 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -39,7 +39,7 @@ class FlashInferBackend(AttentionBackend): def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 10c94f02..8fc1af1a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -1,6 +1,6 @@ """Attention layer ROCm GPUs.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch @@ -43,7 +43,7 @@ class ROCmFlashAttentionBackend(AttentionBackend): def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index c1c07abe..c29218df 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -1,7 +1,7 @@ """ Attention layer with torch scaled_dot_product_attention and PagedAttention.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type import torch from torch.nn.functional import scaled_dot_product_attention @@ -41,7 +41,7 @@ class TorchSDPABackend(AttentionBackend): def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 6f7fd51c..3c010b67 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -196,7 +196,7 @@ class PagedAttention: def swap_blocks( src_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], + src_to_dst: torch.Tensor, ) -> None: src_key_cache = src_kv_cache[0] dst_key_cache = dst_kv_cache[0] diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 4e7392f3..52a170d7 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -473,11 +473,12 @@ class BlockSpaceManagerV1(BlockSpaceManager): def swap_in(self, seq_group: SequenceGroup, - num_lookahead_slots: int = 0) -> Dict[int, int]: + num_lookahead_slots: int = 0) -> List[Tuple[int, int]]: assert (num_lookahead_slots == 0 ), "BlockSpaceManagerV1 does not support lookahead allocation" # CPU block -> GPU block. 
+        # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
@@ -500,14 +501,16 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             cpu_block.block_number: gpu_block.block_number
             for cpu_block, gpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())
 
     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         blocks = self._get_physical_blocks(seq_group)
         return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
 
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         # GPU block -> CPU block.
+        # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             new_block_table: BlockTable = []
@@ -530,7 +533,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             gpu_block.block_number: cpu_block.block_number
             for gpu_block, cpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())
 
     def _free_block_table(self, block_table: BlockTable) -> None:
         # when using a sliding window, each seq will only use up
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index 3b483e67..f0bc9656 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -243,13 +243,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         return AllocStatus.LATER
 
     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         raise NotImplementedError
 
     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         return False
 
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         raise NotImplementedError
 
     def get_num_free_gpu_blocks(self) -> int:
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index ab2c8ea0..b2a5e419 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -1,6 +1,6 @@
 import enum
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import List
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
@@ -69,7 +69,7 @@ class BlockSpaceManager(ABC):
     @abstractmethod
     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         pass
 
     @abstractmethod
@@ -77,7 +77,7 @@ class BlockSpaceManager(ABC):
         pass
 
     @abstractmethod
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         pass
 
     @abstractmethod
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index f426ee95..35e3db18 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -117,10 +117,10 @@ class SchedulerOutputs:
     num_prefill_groups: int
     # Total number of batched tokens.
     num_batched_tokens: int
-    # Blocks to swap in. Dict of CPU -> GPU block number.
-    blocks_to_swap_in: Dict[int, int]
-    # Blocks to swap out. Dict of GPU -> CPU block number.
-    blocks_to_swap_out: Dict[int, int]
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]]
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]]
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]]
     # Sequence groups that are going to be ignored.
@@ -174,7 +174,7 @@ class SchedulerRunningOutputs:
     # Sequences that are swapped out.
     swapped_out: List[SequenceGroup]
     # The blocks to swap out.
-    blocks_to_swap_out: Dict[int, int]
+    blocks_to_swap_out: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -187,7 +187,7 @@ class SchedulerRunningOutputs:
             prefill_seq_groups=[],
             preempted=[],
             swapped_out=[],
-            blocks_to_swap_out={},
+            blocks_to_swap_out=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
         )
@@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs:
     # phase. I.e., it means the prefill has been chunked.
     prefill_seq_groups: List[SequenceGroup]
     # The blocks to swap in.
-    blocks_to_swap_in: Dict[int, int]
+    blocks_to_swap_in: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -219,7 +219,7 @@ class SchedulerSwappedInOutputs:
         return SchedulerSwappedInOutputs(
             decode_seq_groups=[],
             prefill_seq_groups=[],
-            blocks_to_swap_in={},
+            blocks_to_swap_in=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
             infeasible_seq_groups=[],
@@ -392,7 +392,7 @@ class Scheduler:
             scheduling and SchedulerRunningOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_out: Dict[int, int] = {}
+        blocks_to_swap_out: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []
 
         decode_seq_groups: List[ScheduledSequenceGroup] = []
@@ -509,7 +509,7 @@ class Scheduler:
             SchedulerSwappedInOutputs.
         """
        # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_in: Dict[int, int] = {}
+        blocks_to_swap_in: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
@@ -1032,7 +1032,7 @@ class Scheduler:
     def _preempt(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
         preemption_mode: Optional[PreemptionMode] = None,
     ) -> PreemptionMode:
         # If preemption mode is not specified, we determine the mode as follows:
@@ -1073,24 +1073,24 @@ class Scheduler:
     def _preempt_by_swap(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         self._swap_out(seq_group, blocks_to_swap_out)
 
     def _swap_in(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_in: List[Tuple[int, int]],
     ) -> None:
         mapping = self.block_manager.swap_in(seq_group)
-        blocks_to_swap_in.update(mapping)
+        blocks_to_swap_in.extend(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             seq.status = SequenceStatus.RUNNING
 
     def _swap_out(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         if not self.block_manager.can_swap_out(seq_group):
             # FIXME(woosuk): Abort the sequence group instead of aborting the
@@ -1099,7 +1099,7 @@
                 "Aborted due to the lack of CPU swap space. Please increase "
Please increase " "the swap space to avoid this error.") mapping = self.block_manager.swap_out(seq_group) - blocks_to_swap_out.update(mapping) + blocks_to_swap_out.extend(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED diff --git a/vllm/sequence.py b/vllm/sequence.py index b486d1fe..42b508b5 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -741,10 +741,10 @@ class ExecuteModelRequest: """The model execution request.""" # The sequence group metadata list. seq_group_metadata_list: List[SequenceGroupMetadata] - # Blocks to swap in. Dict of CPU -> GPU block number. - blocks_to_swap_in: Dict[int, int] = field(default_factory=dict) - # Blocks to swap out. Dict of GPU -> CPU block number. - blocks_to_swap_out: Dict[int, int] = field(default_factory=dict) + # Blocks to swap in. List of CPU -> GPU block number. + blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list) + # Blocks to swap out. List of GPU -> CPU block number. + blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list) # Blocks to copy. Source to dest block. blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list) # The number of slots for lookahead decoding. diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 26a60c65..1fb63a3e 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -1,5 +1,5 @@ """CacheEngine class for managing the KV cache.""" -from typing import Dict, List +from typing import List import torch @@ -67,12 +67,12 @@ class CacheEngine: device=device)) return kv_cache - def swap_in(self, src_to_dst: Dict[int, int]) -> None: + def swap_in(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i], src_to_dst) - def swap_out(self, src_to_dst: Dict[int, int]) -> None: + def swap_out(self, src_to_dst: torch.Tensor) -> None: for i in range(self.num_layers): self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i], src_to_dst) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 538332ad..313bcf25 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -195,15 +195,14 @@ class Worker(WorkerBase): def cache_swap( self, - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], + blocks_to_swap_in: torch.Tensor, + blocks_to_swap_out: torch.Tensor, blocks_to_copy: torch.Tensor, ) -> None: # Issue cache operations. - # TODO(woosuk): Profile swapping overhead and optimize if needed. - if blocks_to_swap_in: + if blocks_to_swap_in.numel() > 0: self.cache_engine.swap_in(blocks_to_swap_in) - if blocks_to_swap_out: + if blocks_to_swap_out.numel() > 0: self.cache_engine.swap_out(blocks_to_swap_out) if blocks_to_copy.numel() > 0: self.cache_engine.copy(blocks_to_copy) @@ -219,12 +218,26 @@ class Worker(WorkerBase): else: seq_group_metadata_list = execute_model_req.seq_group_metadata_list + blocks_to_swap_in: torch.Tensor + blocks_to_swap_out: torch.Tensor + blocks_to_copy: torch.Tensor if self.is_driver_worker: assert seq_group_metadata_list is not None assert execute_model_req is not None num_seq_groups = len(seq_group_metadata_list) - blocks_to_swap_in = execute_model_req.blocks_to_swap_in - blocks_to_swap_out = execute_model_req.blocks_to_swap_out + # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. + # they contain parameters to launch cudamemcpyasync. 
+            blocks_to_swap_in = torch.tensor(
+                execute_model_req.blocks_to_swap_in,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            blocks_to_swap_out = torch.tensor(
+                execute_model_req.blocks_to_swap_out,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            # `blocks_to_copy` is a GPU tensor. The src and tgt of
+            # blocks to copy are in the same device, and `blocks_to_copy`
+            # can be used directly within cuda kernels.
             blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                           device=self.device,
                                           dtype=torch.int64).view(-1, 2)
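
For reference, a minimal sketch of the swap-mapping convention the hunks above introduce on the worker side, assuming only `torch` is available; the variable names are illustrative and not taken from the patch beyond what the diff itself shows:

```python
import torch

# Scheduler output after this change: a list of (src, dst) block-number pairs.
blocks_to_swap_in = [(19, 45), (67, 23), (12, 78)]  # CPU block -> GPU block

# Keep the tensor on the CPU: csrc/cache_kernels.cu reads each pair with
# `.item()`, which would otherwise trigger a GPU->CPU sync per element.
swap_in_tensor = torch.tensor(blocks_to_swap_in,
                              dtype=torch.int64,
                              device="cpu").view(-1, 2)
assert swap_in_tensor.shape == (3, 2)

# An empty mapping becomes a (0, 2) tensor, which is why the worker now gates
# the swap on `numel() > 0` instead of the old dict truthiness check.
empty = torch.tensor([], dtype=torch.int64, device="cpu").view(-1, 2)
assert empty.numel() == 0
```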
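Similarly, a hedged sketch of the dict-to-list switch on the scheduler and block-manager side; `fake_swap_out` below is a hypothetical stand-in, not a function from the patch:

```python
from typing import List, Tuple


def fake_swap_out(gpu_blocks: List[int],
                  cpu_blocks: List[int]) -> List[Tuple[int, int]]:
    # Stand-in for a block manager's swap_out: GPU -> CPU block pairs,
    # returned as a list of tuples instead of a dict.
    return list(zip(gpu_blocks, cpu_blocks))


# The scheduler now accumulates mappings with list.extend() rather than
# dict.update(), so the result can be handed straight to torch.tensor() later.
blocks_to_swap_out: List[Tuple[int, int]] = []
blocks_to_swap_out.extend(fake_swap_out([3, 56, 84], [72, 35, 34]))
assert [src for src, _ in blocks_to_swap_out] == [3, 56, 84]

# Reversing each pair turns a swap-out mapping into the corresponding swap-in
# mapping, mirroring the check in test_schedule_swapped_simple.
blocks_to_swap_in = [(dst, src) for src, dst in blocks_to_swap_out]
assert blocks_to_swap_in == [(72, 3), (35, 56), (34, 84)]
```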