[v1] Cleanup the BlockTable in InputBatch (#13977)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
parent c3b6559a10
commit e7bd944e08
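In brief, as the hunks below show: BlockTable.append_row no longer takes an explicit start offset and instead derives it from num_blocks_per_row; add_row becomes a reset-then-append; both methods now take (block_ids, row_idx) rather than leading with row_idx; the max_model_len field, apparently unused inside BlockTable, is dropped; and a new test helper _is_req_state_block_table_match checks that a request's block IDs match its row of the batch's block table.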
@@ -89,6 +89,17 @@ def _is_sampling_metadata_changed(model_runner,
         sampling_metadata_before)
 
 
+def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
+    req_index = model_runner.input_batch.req_id_to_index[req_id]
+    block_table = model_runner.input_batch.block_table
+    req_state = model_runner.requests[req_id]
+    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
+        return False
+    num_blocks = block_table.num_blocks_per_row[req_index]
+    return (block_table.block_table_np[req_index, :num_blocks] ==
+            req_state.block_ids).all()
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"
 
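The helper's final check leans on NumPy broadcasting: comparing a row slice against a Python list yields an elementwise boolean array, and .all() demands a match at every position. A tiny standalone illustration (the values here are hypothetical, not taken from the tests):

    import numpy as np

    # A padded 2-D block table, analogous to BlockTable.block_table_np.
    block_table_np = np.zeros((4, 8), dtype=np.int32)
    block_table_np[0, :3] = [10, 11, 12]  # row 0 holds 3 blocks
    num_blocks = 3

    # Slice-vs-list comparison broadcasts elementwise; .all() requires
    # every position to match.
    assert (block_table_np[0, :num_blocks] == [10, 11, 12]).all()
    assert not (block_table_np[0, :num_blocks] == [10, 11, 99]).all()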
@@ -100,6 +111,7 @@ def test_update_states_new_request(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_finished(model_runner):
@@ -185,6 +197,7 @@ def test_update_states_request_resumed(model_runner):
     assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_no_changes(model_runner):
@@ -215,6 +228,7 @@ def test_update_states_no_changes(model_runner):
     assert not _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)
+    assert _is_req_state_block_table_match(model_runner, req_id)
 
 
 def test_update_states_request_unscheduled(model_runner):
@@ -15,13 +15,11 @@ class BlockTable:
     def __init__(
         self,
         max_num_reqs: int,
-        max_model_len: int,
         max_num_blocks_per_req: int,
         pin_memory: bool,
         device: torch.device,
     ):
         self.max_num_reqs = max_num_reqs
-        self.max_model_len = max_model_len
         self.max_num_blocks_per_req = max_num_blocks_per_req
         self.pin_memory = pin_memory
         self.device = device
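Dropping max_model_len here pairs with the InputBatch hunk further below, which stops passing max_model_len= to the BlockTable constructor; the field was presumably never consumed inside BlockTable itself.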
@@ -42,18 +40,19 @@ class BlockTable:
 
     def append_row(
         self,
-        row_idx: int,
-        start: int,
         block_ids: List[int],
+        row_idx: int,
     ) -> None:
         if not block_ids:
             return
         num_blocks = len(block_ids)
+        start = self.num_blocks_per_row[row_idx]
+        self.num_blocks_per_row[row_idx] += num_blocks
         self.block_table_np[row_idx, start:start + num_blocks] = block_ids
-        self.num_blocks_per_row[row_idx] = start + num_blocks
 
-    def add_row(self, row_idx: int, block_ids: List[int]) -> None:
-        self.append_row(row_idx, 0, block_ids)
+    def add_row(self, block_ids: List[int], row_idx: int) -> None:
+        self.num_blocks_per_row[row_idx] = 0
+        self.append_row(block_ids, row_idx)
 
     def move_row(self, src: int, tgt: int) -> None:
         num_blocks = self.num_blocks_per_row[src]
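A minimal, self-contained sketch of the new semantics (NumPy only; the class name MiniBlockTable and the driver lines at the bottom are illustrative, and the real BlockTable additionally maintains torch tensors and pinned-memory copies):

    from typing import List

    import numpy as np

    class MiniBlockTable:
        """NumPy-only sketch of the post-cleanup BlockTable API."""

        def __init__(self, max_num_reqs: int, max_num_blocks_per_req: int):
            self.block_table_np = np.zeros(
                (max_num_reqs, max_num_blocks_per_req), dtype=np.int32)
            self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

        def append_row(self, block_ids: List[int], row_idx: int) -> None:
            if not block_ids:
                return
            num_blocks = len(block_ids)
            # The write offset is now derived internally instead of
            # being passed in by the caller.
            start = self.num_blocks_per_row[row_idx]
            self.num_blocks_per_row[row_idx] += num_blocks
            self.block_table_np[row_idx, start:start + num_blocks] = block_ids

        def add_row(self, block_ids: List[int], row_idx: int) -> None:
            # Reset, then append: replaces the old append_row(row_idx, 0, ...).
            self.num_blocks_per_row[row_idx] = 0
            self.append_row(block_ids, row_idx)

    bt = MiniBlockTable(max_num_reqs=2, max_num_blocks_per_req=8)
    bt.add_row([1, 2, 3], row_idx=0)   # new request: 3 blocks
    bt.append_row([4, 5], row_idx=0)   # scheduler grants 2 more blocks
    assert bt.num_blocks_per_row[0] == 5
    assert list(bt.block_table_np[0, :5]) == [1, 2, 3, 4, 5]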
@@ -92,7 +92,6 @@ class InputBatch:
         # Block table.
         self.block_table = BlockTable(
             max_num_reqs=max_num_reqs,
-            max_model_len=max_model_len,
             max_num_blocks_per_req=max_num_blocks_per_req,
             pin_memory=pin_memory,
             device=device,
@@ -249,7 +248,7 @@ class InputBatch:
         self.num_tokens_no_spec[req_index] = request.num_tokens
 
         self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
-        self.block_table.add_row(req_index, request.block_ids)
+        self.block_table.add_row(request.block_ids, req_index)
 
         sampling_params = request.sampling_params
         if sampling_params.sampling_type == SamplingType.GREEDY:
@@ -399,10 +399,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-            start_index = (len(req_state.block_ids) -
-                           len(req_data.new_block_ids))
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
             # Add new_token_ids to token_ids_cpu.
             start_token_index = num_computed_tokens
             end_token_index = num_computed_tokens + len(req_data.new_token_ids)
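Note that the deleted start_index arithmetic, len(req_state.block_ids) - len(req_data.new_block_ids), equals the row's block count before the append, presumably because req_state.block_ids has already been extended with new_block_ids by this point; since the new append_row reads that offset from num_blocks_per_row internally, the call site can simply pass the new block IDs. The identical cleanup appears in the TPUModelRunner hunk that follows.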
@@ -247,10 +247,8 @@ class TPUModelRunner:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 req_data.num_computed_tokens)
-            start_index = len(req_state.block_ids) - len(
-                req_data.new_block_ids)
-            self.input_batch.block_table.append_row(req_index, start_index,
-                                                    req_data.new_block_ids)
+            self.input_batch.block_table.append_row(req_data.new_block_ids,
+                                                    req_index)
 
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.