[Core][Optimization] Change Python dict to PyTorch tensor for blocks to swap (#4659)
This commit is contained in:
parent ad932a221d
commit 20cfcdec99
@@ -8,12 +8,12 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping);
+  const torch::Tensor& block_mapping);
 
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping);
+  const torch::Tensor& block_mapping);
 
 void reshape_and_cache(
   torch::Tensor& key,
@@ -23,7 +23,7 @@
 void swap_blocks(
   torch::Tensor& src,
   torch::Tensor& dst,
-  const std::map<int64_t, int64_t>& block_mapping) {
+  const torch::Tensor& block_mapping) {
   torch::Device src_device = src.device();
   torch::Device dst_device = dst.device();
   cudaMemcpyKind memcpy_type;
@@ -40,6 +40,11 @@ void swap_blocks(
     TORCH_CHECK(false, "Invalid device combination");
   }
 
+  // NOTE(youkaichao): keep in mind that `block_mapping` should be
+  // a cpu tensor, otherwise every `item` call will require a gpu-cpu
+  // synchronization.
+  TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
+
   char *src_ptr = static_cast<char*>(src.data_ptr());
   char *dst_ptr = static_cast<char*>(dst.data_ptr());
 
@@ -47,9 +52,10 @@ void swap_blocks(
   const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   // NOTE(woosuk): This can be slow if the number of blocks is large.
-  for (const auto& pair : block_mapping) {
-    int64_t src_block_number = pair.first;
-    int64_t dst_block_number = pair.second;
+  const int64_t num_blocks = block_mapping.size(0);
+  for (size_t i = 0; i < num_blocks; i++) {
+    int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+    int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
     int64_t src_offset = src_block_number * block_size_in_bytes;
     int64_t dst_offset = dst_block_number * block_size_in_bytes;
     cudaMemcpyAsync(
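The new `swap_blocks` signature expects the block mapping as an int64 tensor of shape (num_blocks, 2), and the check above insists that it live on the CPU so the per-pair `item()` reads do not force a GPU-CPU synchronization. A minimal sketch of how a caller is expected to package swap pairs for this kernel; the variable names and sample pairs are illustrative only, not part of the diff:

    import torch

    # Hypothetical (src_block, dst_block) pairs; real values come from the scheduler.
    swap_pairs = [(3, 72), (56, 35), (84, 34)]

    # Keep the mapping on the CPU: the kernel wrapper reads each entry with
    # .item(), and .item() on a CUDA tensor would trigger a device sync per call.
    block_mapping = torch.tensor(swap_pairs, dtype=torch.int64,
                                 device="cpu").view(-1, 2)

    assert block_mapping.shape == (len(swap_pairs), 2)
    # ops.swap_blocks(src_cache, dst_cache, block_mapping)  # as in the kernel test below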
@@ -97,7 +103,7 @@ __global__ void copy_blocks_kernel(
 void copy_blocks(
   std::vector<torch::Tensor>& key_caches,
   std::vector<torch::Tensor>& value_caches,
-  torch::Tensor& block_mapping) {
+  const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -83,7 +83,7 @@ void reshape_and_cache_cpu_impl(
 
 void copy_blocks(std::vector<torch::Tensor> &key_caches,
                  std::vector<torch::Tensor> &value_caches,
-                 torch::Tensor& block_mapping) {
+                 const torch::Tensor& block_mapping) {
   int num_layers = key_caches.size();
   TORCH_CHECK(num_layers == value_caches.size());
   if (num_layers == 0) {
@@ -128,6 +128,6 @@ void reshape_and_cache(torch::Tensor &key, torch::Tensor &value,
 }
 
 void swap_blocks(torch::Tensor &src, torch::Tensor &dst,
-                 const std::map<int64_t, int64_t> &block_mapping) {
+                 const torch::Tensor& block_mapping) {
   TORCH_CHECK(false, "swap_blocks is unsupported on CPU.")
 }
@@ -219,7 +219,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_out(seq_group)
-    assert list(mapping.keys()) == gpu_blocks
+    assert [x[0] for x in mapping] == gpu_blocks
     after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks == after_cpu_blocks + len(gpu_blocks)
@@ -232,7 +232,7 @@ def test_swap():
     before_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     before_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     mapping = block_manager.swap_in(seq_group)
-    assert list(mapping.keys()) == cpu_blocks
+    assert [x[0] for x in mapping] == cpu_blocks
    after_cpu_blocks = block_manager.get_num_free_cpu_blocks()
     after_gpu_blocks = block_manager.get_num_free_gpu_blocks()
     assert before_cpu_blocks + len(cpu_blocks) == after_cpu_blocks
@@ -355,8 +355,8 @@ def test_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
 
     # Add 1 more task. Swap should be prioritized over new prefill.
     _, seq_group = create_dummy_prompt("2", prompt_length=60)
@@ -365,8 +365,8 @@ def test_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
 
 
 def test_running_prefill_prioritized_over_swap():
@@ -406,8 +406,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
 
     # Add 1 more task. Swap is not possible, so prefill is running.
     scheduler.block_manager.can_swap_in = MagicMock()
@@ -419,8 +419,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
 
     # Now although swap is possible, running prefill is prioritized.
@@ -429,8 +429,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -440,8 +440,8 @@ def test_running_prefill_prioritized_over_swap():
     assert len(out.scheduled_seq_groups) == 1
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 1
-    assert out.blocks_to_swap_in == {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in == []
+    assert out.blocks_to_swap_out == []
     assert not seq_group2.is_prefill()
     assert out.scheduled_seq_groups[0].seq_group == seq_group2
     append_new_token(seq_group2, 1)
@@ -451,8 +451,8 @@ def test_running_prefill_prioritized_over_swap():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 1
     assert out.num_batched_tokens == 30
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
 
 
 def test_chunked_prefill_preempt():
@@ -493,8 +493,8 @@ def test_chunked_prefill_preempt():
     _, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 0
     assert out.num_batched_tokens == 0
-    assert out.blocks_to_swap_out == {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out == []
+    assert out.blocks_to_swap_in == []
 
     # Make sure we can reschedule preempted request.
     _, out = schedule_and_update_computed_tokens(scheduler)
@@ -293,8 +293,8 @@ def test_swapped_out_prioritized():
     seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
     assert len(out.scheduled_seq_groups) == 2
     assert out.num_batched_tokens == 2
-    assert out.blocks_to_swap_out != {}
-    assert out.blocks_to_swap_in == {}
+    assert out.blocks_to_swap_out != []
+    assert out.blocks_to_swap_in == []
     append_new_token(out, 1)
 
     # Add 1 more task. Swap should be prioritized over prefill.
@@ -305,8 +305,8 @@ def test_swapped_out_prioritized():
     assert len(out.scheduled_seq_groups) == 3
     # 3 decodes. It is swapped in.
     assert out.num_batched_tokens == 3
-    assert out.blocks_to_swap_in != {}
-    assert out.blocks_to_swap_out == {}
+    assert out.blocks_to_swap_in != []
+    assert out.blocks_to_swap_out == []
 
 
 def initialize_scheduler(*,
@@ -566,7 +566,7 @@ def test_decode_schedule_preempted():
     # NOTE: When enable_chunk is False, num_seqs budget is not updated.
     # assert budget.num_curr_seqs == 1
     # Both should be preempted, not swapped.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Nothing is copied.
     assert output.blocks_to_copy == []
 
@@ -599,7 +599,7 @@ def test_decode_swap_beam_search():
     scheduler.block_manager.can_append_slots.side_effect = (
         cannot_append_second_group)
     scheduler.block_manager.swap_out = MagicMock()
-    expected_swap_mapping = {"5": "7"}
+    expected_swap_mapping = [("5", "7")]
     scheduler.block_manager.swap_out.return_value = expected_swap_mapping
 
     remainig_running, output = scheduler._schedule_running(
@@ -647,7 +647,7 @@ def test_schedule_decode_blocks_to_copy_update():
     assert len(output.preempted) == 0
     assert len(output.swapped_out) == 0
     # Nothing is preempted.
-    assert output.blocks_to_swap_out == {}
+    assert output.blocks_to_swap_out == []
     # Since append_slot returns the source -> dist mapping, it should
     # applied.
     assert output.blocks_to_copy == [(2, 3)]
@@ -658,7 +658,7 @@ def test_schedule_swapped_simple():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
@@ -674,9 +674,9 @@ def test_schedule_swapped_simple():
     assert len(output.decode_seq_groups) == 1
     assert len(output.prefill_seq_groups) == 0
     # swap in is the reverse of swap out
-    blocks_to_swap_in_reverse = {}
-    for swapin, swapout in output.blocks_to_swap_in.items():
-        blocks_to_swap_in_reverse[swapout] = swapin
+    blocks_to_swap_in_reverse = []
+    for swapin, swapout in output.blocks_to_swap_in:
+        blocks_to_swap_in_reverse.append((swapout, swapin))
     assert blocks_to_swap_out == blocks_to_swap_in_reverse
 
 
@@ -685,7 +685,7 @@ def test_schedule_swapped_max_token_budget():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -719,7 +719,7 @@ def test_schedule_swapped_max_seqs():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(4):
         _, seq_group = create_dummy_prompt(str(i), prompt_length=60)
         scheduler._allocate_and_set_running(seq_group)
@@ -752,7 +752,7 @@ def test_schedule_swapped_max_loras():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = set()
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
@@ -781,7 +781,7 @@ def test_schedule_swapped_cannot_swap_in():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -808,7 +808,7 @@ def test_infeasible_swap():
     swapped = deque()
     policy = PolicyFactory.get_policy(policy_name="fcfs")
     curr_loras = None
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     for _ in range(2):
         _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
         scheduler._allocate_and_set_running(seq_group)
@@ -839,7 +839,7 @@ def test_schedule_swapped_blocks_to_copy():
     _, seq_group = create_dummy_prompt("1", prompt_length=60, best_of=2)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)
-    blocks_to_swap_out = {}
+    blocks_to_swap_out = []
     scheduler._swap_out(seq_group, blocks_to_swap_out)
     swapped.append(seq_group)
 
@@ -315,7 +315,10 @@ def test_swap_blocks(
     else:
         dst_blocks = random.sample(range(num_blocks), num_mappings)
 
-    block_mapping = dict(zip(src_blocks, dst_blocks))
+    block_mapping = list(zip(src_blocks, dst_blocks))
+    block_mapping_tensor = torch.tensor(block_mapping,
+                                        dtype=torch.int64,
+                                        device="cpu").view(-1, 2)
 
     # Create the KV caches on the first device.
     src_key_caches, src_value_caches = kv_cache_factory(
@@ -331,10 +334,12 @@ def test_swap_blocks(
     src_value_caches_clone = src_value_caches[0].clone()
 
     # Call the swap_blocks kernel.
-    ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping)
-    ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping)
+    ops.swap_blocks(src_key_caches[0], dist_key_caches[0],
+                    block_mapping_tensor)
+    ops.swap_blocks(src_value_caches[0], dist_value_caches[0],
+                    block_mapping_tensor)
 
-    for src, dst in block_mapping.items():
+    for src, dst in block_mapping:
         assert torch.allclose(src_key_caches_clone[src].cpu(),
                               dist_key_caches[0][dst].cpu())
         assert torch.allclose(src_value_caches_clone[src].cpu(),
@@ -54,10 +54,10 @@ def test_swap() -> None:
         a.cuda(), b.cuda(), rtol=0.0, atol=0.0)
 
     # Test swap out.
-    blocks_to_swap_out = {3: 72, 56: 35, 84: 34}
+    blocks_to_swap_out = [(3, 72), (56, 35), (84, 34)]
     execute_model_req = ExecuteModelRequest(
         seq_group_metadata_list=[],
-        blocks_to_swap_in={},
+        blocks_to_swap_in=[],
         blocks_to_swap_out=blocks_to_swap_out,
         blocks_to_copy=[],
     )
@@ -66,24 +66,24 @@ def test_swap() -> None:
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in blocks_to_swap_out.items():
+        for src, dst in blocks_to_swap_out:
             assert allclose(gpu_key_cache[src], cpu_key_cache[dst])
             assert allclose(gpu_value_cache[src], cpu_value_cache[dst])
 
     # Test swap in.
-    execute_model_req.blocks_to_swap_out = {}
-    execute_model_req.blocks_to_swap_in = {
-        19: 45,
-        67: 23,
-        12: 78,
-        40: 99,
-        1: 71
-    }
+    execute_model_req.blocks_to_swap_out = []
+    execute_model_req.blocks_to_swap_in = [
+        (19, 45),
+        (67, 23),
+        (12, 78),
+        (40, 99),
+        (1, 71),
+    ]
     worker.execute_model(execute_model_req=execute_model_req)
 
     for i in range(num_layers):
         gpu_key_cache, gpu_value_cache = gpu_cache[i]
         cpu_key_cache, cpu_value_cache = cpu_cache[i]
-        for src, dst in execute_model_req.blocks_to_swap_in.items():
+        for src, dst in execute_model_req.blocks_to_swap_in:
             assert allclose(gpu_key_cache[dst], cpu_key_cache[src])
             assert allclose(gpu_value_cache[dst], cpu_value_cache[src])
@@ -39,7 +39,7 @@ class AttentionBackend(ABC):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
@@ -5,7 +5,7 @@ XFormers backend. The duplicated code will be removed once we use flash-attn or
 flashinfer for all the attention operations.
 """
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
 
 import torch
 from flash_attn import flash_attn_varlen_func
@@ -45,7 +45,7 @@ class FlashAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
@@ -39,7 +39,7 @@ class FlashInferBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         raise NotImplementedError
 
@@ -1,6 +1,6 @@
 """Attention layer ROCm GPUs."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
 
 import torch
 
@@ -43,7 +43,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
@@ -1,7 +1,7 @@
 """ Attention layer with torch scaled_dot_product_attention
 and PagedAttention."""
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple, Type
+from typing import List, Optional, Tuple, Type
 
 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -41,7 +41,7 @@ class TorchSDPABackend(AttentionBackend):
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
 
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -196,7 +196,7 @@ class PagedAttention:
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
-        src_to_dst: Dict[int, int],
+        src_to_dst: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
         dst_key_cache = dst_kv_cache[0]
@@ -473,11 +473,12 @@ class BlockSpaceManagerV1(BlockSpaceManager):
 
     def swap_in(self,
                 seq_group: SequenceGroup,
-                num_lookahead_slots: int = 0) -> Dict[int, int]:
+                num_lookahead_slots: int = 0) -> List[Tuple[int, int]]:
         assert (num_lookahead_slots == 0
                 ), "BlockSpaceManagerV1 does not support lookahead allocation"
 
         # CPU block -> GPU block.
+        # dict is efficient in lookup `if cpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
@@ -500,14 +501,16 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             cpu_block.block_number: gpu_block.block_number
             for cpu_block, gpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())
 
     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         blocks = self._get_physical_blocks(seq_group)
         return len(blocks) <= self.cpu_allocator.get_num_free_blocks()
 
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         # GPU block -> CPU block.
+        # dict is efficient in lookup `if gpu_block in mapping`
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             new_block_table: BlockTable = []
@@ -530,7 +533,8 @@ class BlockSpaceManagerV1(BlockSpaceManager):
             gpu_block.block_number: cpu_block.block_number
             for gpu_block, cpu_block in mapping.items()
         }
-        return block_number_mapping
+        # convert to list of tuples once here
+        return list(block_number_mapping.items())
 
     def _free_block_table(self, block_table: BlockTable) -> None:
         # when using a sliding window, each seq will only use up
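swap_in and swap_out above keep using a dict internally, so membership checks such as `if cpu_block in mapping` stay cheap while block tables are rewritten, and the result is converted to the list-of-tuples form exactly once at the return. A simplified sketch of that pattern; the helper name and the block numbers are made up for illustration:

    from typing import Dict, List, Tuple

    def as_swap_list(block_mapping: Dict[int, int]) -> List[Tuple[int, int]]:
        # Single conversion at the boundary; callers extend their swap lists with it.
        return list(block_mapping.items())

    mapping = {5: 7, 9: 2}  # e.g. GPU block -> CPU block (illustrative numbers)
    blocks_to_swap_out: List[Tuple[int, int]] = []
    blocks_to_swap_out.extend(as_swap_list(mapping))
    assert blocks_to_swap_out == [(5, 7), (9, 2)]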
@@ -243,13 +243,13 @@ class BlockSpaceManagerV2(BlockSpaceManager):
         return AllocStatus.LATER
 
     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         raise NotImplementedError
 
     def can_swap_out(self, seq_group: SequenceGroup) -> bool:
         return False
 
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         raise NotImplementedError
 
     def get_num_free_gpu_blocks(self) -> int:
@@ -1,6 +1,6 @@
 import enum
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import List
 from typing import Sequence as GenericSequence
 from typing import Tuple
 
@@ -69,7 +69,7 @@ class BlockSpaceManager(ABC):
 
     @abstractmethod
     def swap_in(self, seq_group: SequenceGroup,
-                num_lookahead_slots: int) -> Dict[int, int]:
+                num_lookahead_slots: int) -> List[Tuple[int, int]]:
         pass
 
     @abstractmethod
@@ -77,7 +77,7 @@ class BlockSpaceManager(ABC):
         pass
 
     @abstractmethod
-    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
+    def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
         pass
 
     @abstractmethod
@@ -117,10 +117,10 @@ class SchedulerOutputs:
     num_prefill_groups: int
     # Total number of batched tokens.
     num_batched_tokens: int
-    # Blocks to swap in. Dict of CPU -> GPU block number.
-    blocks_to_swap_in: Dict[int, int]
-    # Blocks to swap out. Dict of GPU -> CPU block number.
-    blocks_to_swap_out: Dict[int, int]
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]]
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]]
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]]
     # Sequence groups that are going to be ignored.
@@ -174,7 +174,7 @@ class SchedulerRunningOutputs:
     # Sequences that are swapped out.
     swapped_out: List[SequenceGroup]
     # The blocks to swap out.
-    blocks_to_swap_out: Dict[int, int]
+    blocks_to_swap_out: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -187,7 +187,7 @@ class SchedulerRunningOutputs:
             prefill_seq_groups=[],
             preempted=[],
             swapped_out=[],
-            blocks_to_swap_out={},
+            blocks_to_swap_out=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
         )
@@ -206,7 +206,7 @@ class SchedulerSwappedInOutputs:
     # phase. I.e., it means the prefill has been chunked.
     prefill_seq_groups: List[SequenceGroup]
     # The blocks to swap in.
-    blocks_to_swap_in: Dict[int, int]
+    blocks_to_swap_in: List[Tuple[int, int]]
     # The blocks to copy.
     blocks_to_copy: List[Tuple[int, int]]
     # The number of slots for lookahead decoding.
@@ -219,7 +219,7 @@ class SchedulerSwappedInOutputs:
         return SchedulerSwappedInOutputs(
             decode_seq_groups=[],
             prefill_seq_groups=[],
-            blocks_to_swap_in={},
+            blocks_to_swap_in=[],
             blocks_to_copy=[],
             num_lookahead_slots=0,
             infeasible_seq_groups=[],
@@ -392,7 +392,7 @@ class Scheduler:
             scheduling and SchedulerRunningOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_out: Dict[int, int] = {}
+        blocks_to_swap_out: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []
 
         decode_seq_groups: List[ScheduledSequenceGroup] = []
@@ -509,7 +509,7 @@ class Scheduler:
             SchedulerSwappedInOutputs.
         """
         # Blocks that need to be swapped or copied before model execution.
-        blocks_to_swap_in: Dict[int, int] = {}
+        blocks_to_swap_in: List[Tuple[int, int]] = []
         blocks_to_copy: List[Tuple[int, int]] = []
         decode_seq_groups: List[ScheduledSequenceGroup] = []
         prefill_seq_groups: List[ScheduledSequenceGroup] = []
@@ -1032,7 +1032,7 @@ class Scheduler:
     def _preempt(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
         preemption_mode: Optional[PreemptionMode] = None,
     ) -> PreemptionMode:
         # If preemption mode is not specified, we determine the mode as follows:
@@ -1073,24 +1073,24 @@
     def _preempt_by_swap(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         self._swap_out(seq_group, blocks_to_swap_out)
 
     def _swap_in(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_in: Dict[int, int],
+        blocks_to_swap_in: List[Tuple[int, int]],
     ) -> None:
         mapping = self.block_manager.swap_in(seq_group)
-        blocks_to_swap_in.update(mapping)
+        blocks_to_swap_in.extend(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             seq.status = SequenceStatus.RUNNING
 
     def _swap_out(
         self,
         seq_group: SequenceGroup,
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_out: List[Tuple[int, int]],
     ) -> None:
         if not self.block_manager.can_swap_out(seq_group):
             # FIXME(woosuk): Abort the sequence group instead of aborting the
@@ -1099,7 +1099,7 @@ class Scheduler:
                 "Aborted due to the lack of CPU swap space. Please increase "
                 "the swap space to avoid this error.")
         mapping = self.block_manager.swap_out(seq_group)
-        blocks_to_swap_out.update(mapping)
+        blocks_to_swap_out.extend(mapping)
         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             seq.status = SequenceStatus.SWAPPED
 
@@ -741,10 +741,10 @@ class ExecuteModelRequest:
     """The model execution request."""
     # The sequence group metadata list.
    seq_group_metadata_list: List[SequenceGroupMetadata]
-    # Blocks to swap in. Dict of CPU -> GPU block number.
-    blocks_to_swap_in: Dict[int, int] = field(default_factory=dict)
-    # Blocks to swap out. Dict of GPU -> CPU block number.
-    blocks_to_swap_out: Dict[int, int] = field(default_factory=dict)
+    # Blocks to swap in. List of CPU -> GPU block number.
+    blocks_to_swap_in: List[Tuple[int, int]] = field(default_factory=list)
+    # Blocks to swap out. List of GPU -> CPU block number.
+    blocks_to_swap_out: List[Tuple[int, int]] = field(default_factory=list)
     # Blocks to copy. Source to dest block.
     blocks_to_copy: List[Tuple[int, int]] = field(default_factory=list)
     # The number of slots for lookahead decoding.
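With these field types, a request is built from plain lists of (src, dst) pairs and the worker converts them to tensors just before launching the cache ops. An illustrative construction, mirroring the worker test earlier in this diff; the values are made up and the import path is an assumption based on where this dataclass lives:

    from vllm.sequence import ExecuteModelRequest

    execute_model_req = ExecuteModelRequest(
        seq_group_metadata_list=[],
        blocks_to_swap_in=[(19, 45), (67, 23)],  # CPU block -> GPU block
        blocks_to_swap_out=[(3, 72)],            # GPU block -> CPU block
        blocks_to_copy=[(2, 3)],                 # src block -> dst block
    )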
@@ -1,5 +1,5 @@
 """CacheEngine class for managing the KV cache."""
-from typing import Dict, List
+from typing import List
 
 import torch
 
@@ -67,12 +67,12 @@ class CacheEngine:
                     device=device))
         return kv_cache
 
-    def swap_in(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_in(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_layers):
             self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
                                           src_to_dst)
 
-    def swap_out(self, src_to_dst: Dict[int, int]) -> None:
+    def swap_out(self, src_to_dst: torch.Tensor) -> None:
         for i in range(self.num_layers):
             self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
                                           src_to_dst)
@@ -195,15 +195,14 @@ class Worker(WorkerBase):
 
     def cache_swap(
         self,
-        blocks_to_swap_in: Dict[int, int],
-        blocks_to_swap_out: Dict[int, int],
+        blocks_to_swap_in: torch.Tensor,
+        blocks_to_swap_out: torch.Tensor,
         blocks_to_copy: torch.Tensor,
     ) -> None:
         # Issue cache operations.
-        # TODO(woosuk): Profile swapping overhead and optimize if needed.
-        if blocks_to_swap_in:
+        if blocks_to_swap_in.numel() > 0:
             self.cache_engine.swap_in(blocks_to_swap_in)
-        if blocks_to_swap_out:
+        if blocks_to_swap_out.numel() > 0:
             self.cache_engine.swap_out(blocks_to_swap_out)
         if blocks_to_copy.numel() > 0:
             self.cache_engine.copy(blocks_to_copy)
@@ -219,12 +218,26 @@
         else:
             seq_group_metadata_list = execute_model_req.seq_group_metadata_list
 
+        blocks_to_swap_in: torch.Tensor
+        blocks_to_swap_out: torch.Tensor
+        blocks_to_copy: torch.Tensor
         if self.is_driver_worker:
             assert seq_group_metadata_list is not None
             assert execute_model_req is not None
             num_seq_groups = len(seq_group_metadata_list)
-            blocks_to_swap_in = execute_model_req.blocks_to_swap_in
-            blocks_to_swap_out = execute_model_req.blocks_to_swap_out
+            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
+            # they contain parameters to launch cudamemcpyasync.
+            blocks_to_swap_in = torch.tensor(
+                execute_model_req.blocks_to_swap_in,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            blocks_to_swap_out = torch.tensor(
+                execute_model_req.blocks_to_swap_out,
+                device="cpu",
+                dtype=torch.int64).view(-1, 2)
+            # `blocks_to_copy` is a gpu tensor. The src and tgt of
+            # blocks to copy are in the same device, and `blocks_to_copy`
+            # can be used directly within cuda kernels.
             blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                           device=self.device,
                                           dtype=torch.int64).view(-1, 2)
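Putting the pieces together: the scheduler hands the worker plain Python lists of pairs, the driver worker packs the swap lists into CPU tensors (they only parameterize cudaMemcpyAsync launches) and blocks_to_copy into a GPU tensor, and empty work is skipped via the numel() guards in cache_swap. A rough sketch of that flow, with a hypothetical helper standing in for the real engine objects:

    import torch

    def pack_swap_pairs(pairs):
        # Swap pairs stay on the CPU; an empty list yields a (0, 2) tensor.
        return torch.tensor(pairs, device="cpu", dtype=torch.int64).view(-1, 2)

    blocks_to_swap_in = pack_swap_pairs([])           # nothing to swap in
    blocks_to_swap_out = pack_swap_pairs([(3, 72)])   # one GPU -> CPU swap

    # Mirrors the numel() guards in cache_swap above.
    if blocks_to_swap_in.numel() > 0:
        pass  # cache_engine.swap_in(blocks_to_swap_in) in the real worker
    if blocks_to_swap_out.numel() > 0:
        pass  # cache_engine.swap_out(blocks_to_swap_out) in the real worker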