[MLA] Simplification to batch P/D reordering (#16673)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-04-17 13:12:09 -07:00 (committed by GitHub)
Commit: 0377b8310b
Parent: e4755f7fac
2 changed files with 12 additions and 16 deletions

@@ -415,20 +415,18 @@ class MLACommonMetadataBuilder(Generic[M]):
             # the above loop
             num_decodes = len(decodes)
             num_prefills = len(prefills)
-            first_prefill = 0
             modified_batch = False
 
             for i in range(1, min(num_decodes, num_prefills) + 1):
                 # If the decode is at the "back" of the batch, i, we can swap it
                 # with the prefill closest to the front of the batch
-                if decodes[num_decodes - i] >= num_decodes:
-                    input_batch.swap_states(prefills[first_prefill],
-                                            decodes[num_decodes - i])
-                    first_prefill += 1
-                    modified_batch = True
-                else:
+                decode_idx = decodes[num_decodes - i]
+                if decode_idx < num_decodes:
                     break
+
+                input_batch.swap_states(prefills[i - 1], decode_idx)
+                modified_batch = True
 
             # Save for next `build` call
             # TODO(lucas): this is a bit of a hack, we should probably have a
             # better way of doing this
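
The simplification is easier to see in isolation: on each iteration, the i-th decode from the back of the batch is either already inside the decode region [0, num_decodes), in which case every remaining decode is too and the loop can stop, or it trades places with the i-th prefill from the front, which must then be occupying a slot in the decode region. That symmetry is what lets the first_prefill counter be dropped in favor of prefills[i - 1]. Below is a runnable toy sketch, with ToyBatch and reorder as hypothetical stand-ins for vLLM's InputBatch and MLACommonMetadataBuilder.reorder_batch; only the swap bookkeeping is modeled.

    class ToyBatch:
        """Hypothetical stand-in for InputBatch: one 'D'/'P' label per slot."""

        def __init__(self, kinds):
            self.kinds = list(kinds)

        def swap_states(self, i, j):
            # vLLM swaps all per-request state here; the sketch swaps labels.
            self.kinds[i], self.kinds[j] = self.kinds[j], self.kinds[i]


    def reorder(batch):
        # Index lists built from a single left-to-right scan, as in the
        # loop that precedes this hunk.
        decodes = [i for i, k in enumerate(batch.kinds) if k == "D"]
        prefills = [i for i, k in enumerate(batch.kinds) if k == "P"]
        num_decodes = len(decodes)

        modified_batch = False
        for i in range(1, min(num_decodes, len(prefills)) + 1):
            decode_idx = decodes[num_decodes - i]
            if decode_idx < num_decodes:
                break  # this decode and all earlier ones are already in place
            batch.swap_states(prefills[i - 1], decode_idx)
            modified_batch = True
        return modified_batch


    batch = ToyBatch("PDPDD")
    print(reorder(batch), batch.kinds)  # True ['D', 'D', 'D', 'P', 'P']

The index lists go stale after the first swap but stay valid for this purpose, because each slot is touched at most once: step i swaps prefills[i - 1] with decodes[num_decodes - i], neither of which can have been moved by an earlier iteration.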

@@ -458,7 +458,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
 
-        if batch_changed:
+        # Some attention backends (namely MLA) may want to separate requests
+        # based on if the attention computation will be compute-bound or
+        # memory-bound. This gives them a hook to do that.
+        batch_reordered = self.attn_metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)
+
+        if batch_changed or batch_reordered:
             self.input_batch.refresh_sampling_metadata()
 
     def _prepare_inputs(
@@ -471,14 +477,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        # Some attention backends (namely MLA) may want to separate requests
-        # based on if the attention computation will be compute-bound or
-        # memory-bound. This gives them a hook to do that.
-        modified_batch = self.attn_metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
-
-        if modified_batch:
-            self.input_batch.refresh_sampling_metadata()
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
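
Taken together, the two gpu_model_runner.py hunks relocate the backend hook rather than change it: reorder_batch now runs unconditionally in _update_states, before any input preparation, and refresh_sampling_metadata fires at most once per step, when either the scheduler changed the batch or the backend reordered it (previously _update_states and _prepare_inputs could each trigger a refresh in the same step). The following runnable sketch of the new control flow uses hypothetical stubs (StubBatch, StubBuilder, SketchRunner) and passes batch_changed in as a parameter for brevity, where the real method computes it from the step's added and removed requests.

    class StubBatch:
        """Hypothetical stand-in for InputBatch; counts metadata refreshes."""

        def __init__(self):
            self.refreshes = 0

        def refresh_sampling_metadata(self):
            self.refreshes += 1


    class StubBuilder:
        """Hypothetical stand-in for an attention metadata builder."""

        def reorder_batch(self, input_batch, scheduler_output):
            # MLA's builder would swap decode/prefill states here and
            # return True; most backends leave the batch alone.
            return False


    class SketchRunner:
        def __init__(self):
            self.input_batch = StubBatch()
            self.attn_metadata_builder = StubBuilder()

        def _update_states(self, scheduler_output, batch_changed):
            # The hook now runs here every step, before input preparation.
            batch_reordered = self.attn_metadata_builder.reorder_batch(
                self.input_batch, scheduler_output)
            # One refresh covers both causes of a stale sampling layout.
            if batch_changed or batch_reordered:
                self.input_batch.refresh_sampling_metadata()


    runner = SketchRunner()
    runner._update_states(scheduler_output=None, batch_changed=True)
    print(runner.input_batch.refreshes)  # 1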