[MLA] Simplification to batch P/D reordering (#16673)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
e4755f7fac
commit
0377b8310b
@ -415,20 +415,18 @@ class MLACommonMetadataBuilder(Generic[M]):
|
||||
# the above loop
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
first_prefill = 0
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
# with the prefill closest to the front of the batch
|
||||
if decodes[num_decodes - i] >= num_decodes:
|
||||
input_batch.swap_states(prefills[first_prefill],
|
||||
decodes[num_decodes - i])
|
||||
first_prefill += 1
|
||||
modified_batch = True
|
||||
else:
|
||||
decode_idx = decodes[num_decodes - i]
|
||||
if decode_idx < num_decodes:
|
||||
break
|
||||
|
||||
input_batch.swap_states(prefills[i - 1], decode_idx)
|
||||
modified_batch = True
|
||||
|
||||
# Save for next `build` call
|
||||
# TODO(lucas): this is a bit of a hack, we should probably have a
|
||||
# better way of doing this
|
||||
|
@ -458,7 +458,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
if removed_req_indices:
|
||||
self.input_batch.condense(removed_req_indices)
|
||||
|
||||
if batch_changed:
|
||||
# Some attention backends (namely MLA) may want to separate requests
|
||||
# based on if the attention computation will be compute-bound or
|
||||
# memory-bound. This gives them a hook to do that.
|
||||
batch_reordered = self.attn_metadata_builder.reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
|
||||
if batch_changed or batch_reordered:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _prepare_inputs(
|
||||
@ -471,14 +477,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
assert num_reqs > 0
|
||||
|
||||
# Some attention backends (namely MLA) may want to separate requests
|
||||
# based on if the attention computation will be compute-bound or
|
||||
# memory-bound. This gives them a hook to do that.
|
||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
if modified_batch:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
self.input_batch.block_table.commit(num_reqs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user