diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py
index 973efcbf..ff4058a3 100644
--- a/tests/v1/worker/test_gpu_model_runner.py
+++ b/tests/v1/worker/test_gpu_model_runner.py
@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
         sampling_metadata_before)
 
 
+def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    req_index = model_runner.input_batch.req_id_to_index[req_id]
+    block_table = model_runner.input_batch.block_table
+    req_state = model_runner.requests[req_id]
+    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+        return False
+    num_blocks = block_table.num_blocks_per_row[req_index]
+    return (block_table.block_table_np[req_index, :num_blocks] ==
+            req_state.block_ids).all()
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"
 
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_finished(model_runner):
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_no_changes(model_runner):
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
     assert not _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_unscheduled(model_runner):
diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py
index 669175f5..830cca10 100644
--- a/vllm/v1/worker/block_table.py
+++ b/vllm/v1/worker/block_table.py
@@ -15,13 +15,11 @@ class BlockTable:
     def __init__(
         self,
         max_num_reqs: int,
-        max_model_len: int,
         max_num_blocks_per_req: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
-        self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.pin_memory = pin_memory
         self.device = device
@@ -42,18 +40,19 @@ class BlockTable:
 
     def append_row(
         self,
-        row_idx: int,
-        start: int,
         block_ids: List[int],
+        row_idx: int,
     ) -> None:
         if not block_ids:
             return
         num_blocks = len(block_ids)
+        start = self.num_blocks_per_row[row_idx]
+        self.num_blocks_per_row[row_idx] += num_blocks
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
-        self.num_blocks_per_row[row_idx] = start + num_blocks
 
-    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
-        self.append_row(row_idx, 0, block_ids)
+    def add_row(self, block_ids: List[int], row_idx: int) -> None:
+        self.num_blocks_per_row[row_idx] = 0
+        self.append_row(block_ids, row_idx)
 
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index 1b6ea559..788a3522 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -92,7 +92,6 @@ class InputBatch:
         # Block table.
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
-            max_model_len=max_model_len,
             max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
             device=device,
@@ -249,7 +248,7 @@ class InputBatch:
         self.num_tokens_no_spec[req_index] = request.num_tokens
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        self.block_table.add_row(req_index, request.block_ids)
+        self.block_table.add_row(request.block_ids, req_index)
 
         sampling_params = request.sampling_params
         if sampling_params.sampling_type == SamplingType.GREEDY:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index e255becb..0215b273 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -399,10 +399,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            start_index = (len(req_state.block_ids) -
-                           len(req_data.new_block_ids))
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
             # Add new_token_ids to token_ids_cpu.
             start_token_index = num_computed_tokens
             end_token_index = num_computed_tokens + len(req_data.new_token_ids)
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index d16a0a41..2c6a0371 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -247,10 +247,8 @@ class TPUModelRunner:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
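
Below is a minimal, self-contained sketch (not part of the patch) of the updated `append_row`/`add_row` semantics: block IDs now come first in the argument list, and `append_row` derives the write offset from `num_blocks_per_row` instead of taking a `start` argument from the caller. The `ToyBlockTable` name and the sizes are made up for illustration; it mirrors only the numpy bookkeeping shown in the diff, while the real `BlockTable` also maintains the torch CPU/GPU tensors omitted here.

```python
import numpy as np
from typing import List


class ToyBlockTable:
    """Illustrative stand-in mirroring the numpy bookkeeping in the diff."""

    def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
        self.block_table_np = np.zeros((max_num_reqs, max_num_blocks_per_req),
                                       dtype=np.int32)
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(self, block_ids: List[int], row_idx: int) -> None:
        # The write offset is derived from the row's current length,
        # so callers no longer pass a `start` index.
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
        self.block_table_np[row_idx, start:start + num_blocks] = block_ids

    def add_row(self, block_ids: List[int], row_idx: int) -> None:
        # Reset the row length, then reuse append_row.
        self.num_blocks_per_row[row_idx] = 0
        self.append_row(block_ids, row_idx)


table = ToyBlockTable(max_num_reqs=4, max_num_blocks_per_req=8)
table.add_row([10, 11], row_idx=0)     # row 0 -> [10, 11]
table.append_row([12], row_idx=0)      # row 0 -> [10, 11, 12]
assert table.num_blocks_per_row[0] == 3
assert (table.block_table_np[0, :3] == [10, 11, 12]).all()
```

With the offset tracked inside the table, call sites such as `GPUModelRunner` and `TPUModelRunner` no longer need to recompute `start_index` from `req_state.block_ids`, which is what the model-runner hunks above remove.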