[v1] Cleanup the BlockTable in InputBatch (#13977)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Authored by Chen Zhang on 2025-03-01 03:03:16 +08:00; committed by GitHub
parent c3b6559a10
commit e7bd944e08
5 changed files with 25 additions and 17 deletions

View File

@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
                                        sampling_metadata_before)
 
 
+def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    req_index = model_runner.input_batch.req_id_to_index[req_id]
+    block_table = model_runner.input_batch.block_table
+    req_state = model_runner.requests[req_id]
+    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+        return False
+    num_blocks = block_table.num_blocks_per_row[req_index]
+    return (block_table.block_table_np[req_index, :num_blocks] ==
+            req_state.block_ids).all()
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_finished(model_runner):
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_no_changes(model_runner):
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
     assert not _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_unscheduled(model_runner):
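
The helper added at the top of this file makes the block-table invariant explicit: the row length recorded in num_blocks_per_row must equal the number of block IDs in the per-request state, and the corresponding prefix of block_table_np must match those IDs element-wise. A minimal, runnable sketch of that two-step comparison (plain NumPy; the arrays and the row_matches name are illustrative stand-ins for the model-runner state, not part of this diff):

    from typing import List

    import numpy as np

    # Hypothetical stand-ins for the InputBatch state the helper inspects:
    # one row per request, zero-padded past each row's real length.
    block_table_np = np.zeros((4, 8), dtype=np.int32)
    num_blocks_per_row = np.zeros(4, dtype=np.int32)

    # Pretend the request at row 2 holds blocks [7, 3, 9].
    block_table_np[2, :3] = [7, 3, 9]
    num_blocks_per_row[2] = 3

    def row_matches(row_idx: int, block_ids: List[int]) -> bool:
        # Same two-step check as _is_req_state_block_table_match:
        # lengths must agree, then the row prefix must match element-wise.
        if num_blocks_per_row[row_idx] != len(block_ids):
            return False
        num_blocks = num_blocks_per_row[row_idx]
        return bool((block_table_np[row_idx, :num_blocks] == block_ids).all())

    assert row_matches(2, [7, 3, 9])
    assert not row_matches(2, [7, 3])      # length mismatch
    assert not row_matches(2, [7, 3, 8])   # content mismatch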

View File

@@ -15,13 +15,11 @@ class BlockTable:
     def __init__(
         self,
         max_num_reqs: int,
-        max_model_len: int,
         max_num_blocks_per_req: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
-        self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.pin_memory = pin_memory
         self.device = device
@@ -42,18 +40,19 @@ class BlockTable:
 
     def append_row(
         self,
-        row_idx: int,
-        start: int,
         block_ids: List[int],
+        row_idx: int,
     ) -> None:
         if not block_ids:
             return
         num_blocks = len(block_ids)
+        start = self.num_blocks_per_row[row_idx]
+        self.num_blocks_per_row[row_idx] += num_blocks
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
-        self.num_blocks_per_row[row_idx] = start + num_blocks
 
-    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
-        self.append_row(row_idx, 0, block_ids)
+    def add_row(self, block_ids: List[int], row_idx: int) -> None:
+        self.num_blocks_per_row[row_idx] = 0
+        self.append_row(block_ids, row_idx)
 
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
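
The heart of the cleanup is that BlockTable now owns its row bookkeeping: append_row reads the start offset from num_blocks_per_row rather than taking a caller-supplied start, and add_row becomes "reset the row length, then append". A minimal sketch of the new contract (BlockTableSketch is an illustrative NumPy-only stand-in; the real class also maintains the device and pinned-CPU torch tensors):

    from typing import List

    import numpy as np

    class BlockTableSketch:
        """Trimmed-down illustration of the new append_row/add_row contract."""

        def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
            self.block_table_np = np.zeros(
                (max_num_reqs, max_num_blocks_per_req), dtype=np.int32)
            self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        def append_row(self, block_ids: List[int], row_idx: int) -> None:
            if not block_ids:
                return
            num_blocks = len(block_ids)
            # The table, not the caller, knows where the row currently ends.
            start = self.num_blocks_per_row[row_idx]
            self.num_blocks_per_row[row_idx] += num_blocks
            self.block_table_np[row_idx, start:start + num_blocks] = block_ids

        def add_row(self, block_ids: List[int], row_idx: int) -> None:
            # "Add" is just "reset length, then append".
            self.num_blocks_per_row[row_idx] = 0
            self.append_row(block_ids, row_idx)

    table = BlockTableSketch(max_num_reqs=2, max_num_blocks_per_req=8)
    table.add_row([10, 11], row_idx=0)     # row 0 -> [10, 11]
    table.append_row([12], row_idx=0)      # row 0 -> [10, 11, 12]
    assert table.num_blocks_per_row[0] == 3

This matches how the call sites below use it: prefill installs a fresh row with add_row, decode appends only the newly allocated blocks with append_row, and neither caller needs to know how long the row already is.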

View File

@@ -92,7 +92,6 @@ class InputBatch:
         # Block table.
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
-            max_model_len=max_model_len,
             max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
             device=device,
@@ -249,7 +248,7 @@ class InputBatch:
         self.num_tokens_no_spec[req_index] = request.num_tokens
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        self.block_table.add_row(req_index, request.block_ids)
+        self.block_table.add_row(request.block_ids, req_index)
 
         sampling_params = request.sampling_params
         if sampling_params.sampling_type == SamplingType.GREEDY:

View File

@@ -399,10 +399,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            start_index = (len(req_state.block_ids) -
-                           len(req_data.new_block_ids))
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
             # Add new_token_ids to token_ids_cpu.
             start_token_index = num_computed_tokens
             end_token_index = num_computed_tokens + len(req_data.new_token_ids)
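
With the offset tracked inside BlockTable, the runner no longer derives start_index from len(req_state.block_ids) - len(req_data.new_block_ids) before appending; it simply hands over the new blocks (the TPU runner below gets the identical treatment). Continuing the hypothetical BlockTableSketch from above, the equivalence the call site now relies on can be checked directly:

    # Appending only the newly allocated blocks lands them exactly where the
    # old start_index arithmetic used to put them, because the table already
    # knows the row's current length.
    table = BlockTableSketch(max_num_reqs=1, max_num_blocks_per_req=8)
    table.add_row([10, 11], row_idx=0)       # prefill: initial block_ids
    table.append_row([12, 13], row_idx=0)    # decode: new_block_ids only
    assert list(table.block_table_np[0, :4]) == [10, 11, 12, 13]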

View File

@@ -247,10 +247,8 @@ class TPUModelRunner:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.