diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 8c7179ba..b77e9525 100644
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -415,20 +415,18 @@ class MLACommonMetadataBuilder(Generic[M]):
         # the above loop
         num_decodes = len(decodes)
         num_prefills = len(prefills)
-        first_prefill = 0
         modified_batch = False
 
         for i in range(1, min(num_decodes, num_prefills) + 1):
             # If the decode is at the "back" of the batch, i, we can swap it
             # with the prefill closest to the front of the batch
-            if decodes[num_decodes - i] >= num_decodes:
-                input_batch.swap_states(prefills[first_prefill],
-                                        decodes[num_decodes - i])
-                first_prefill += 1
-                modified_batch = True
-            else:
+            decode_idx = decodes[num_decodes - i]
+            if decode_idx < num_decodes:
                 break
 
+            input_batch.swap_states(prefills[i - 1], decode_idx)
+            modified_batch = True
+
         # Save for next `build` call
         # TODO(lucas): this is a bit of a hack, we should probably have a
         # better way of doing this
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index bfdb0f72..baf0dfb9 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -458,7 +458,13 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if removed_req_indices:
             self.input_batch.condense(removed_req_indices)
 
-        if batch_changed:
+        # Some attention backends (namely MLA) may want to separate requests
+        # based on if the attention computation will be compute-bound or
+        # memory-bound. This gives them a hook to do that.
+        batch_reordered = self.attn_metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)
+
+        if batch_changed or batch_reordered:
             self.input_batch.refresh_sampling_metadata()
 
     def _prepare_inputs(
@@ -471,14 +477,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         num_reqs = self.input_batch.num_reqs
         assert num_reqs > 0
 
-        # Some attention backends (namely MLA) may want to separate requests
-        # based on if the attention computation will be compute-bound or
-        # memory-bound. This gives them a hook to do that.
-        modified_batch = self.attn_metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
-        if modified_batch:
-            self.input_batch.refresh_sampling_metadata()
-
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit(num_reqs)
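
For context, the rewritten loop in `MLACommonMetadataBuilder.reorder_batch` partitions the batch so that decode requests occupy the first `num_decodes` slots: any decode stranded at the back of the batch is swapped with a prefill sitting near the front. Below is a minimal, self-contained sketch of that swap logic. It operates on a plain boolean list standing in for the input batch, and `reorder_decodes_first`/`swap` are hypothetical helpers for illustration, not vLLM APIs.

```python
def reorder_decodes_first(is_decode: list[bool]) -> tuple[list[bool], bool]:
    """Sketch of the decode/prefill partitioning used by the MLA builder."""
    decodes = [i for i, d in enumerate(is_decode) if d]
    prefills = [i for i, d in enumerate(is_decode) if not d]
    num_decodes = len(decodes)

    def swap(a: int, b: int) -> None:
        # Stand-in for input_batch.swap_states(a, b).
        is_decode[a], is_decode[b] = is_decode[b], is_decode[a]

    modified = False
    for i in range(1, min(num_decodes, len(prefills)) + 1):
        decode_idx = decodes[num_decodes - i]
        if decode_idx < num_decodes:
            # All remaining decodes already sit within the first
            # `num_decodes` slots, so the batch is fully partitioned.
            break
        # The i-th decode from the back sits past the decode region, so the
        # (i-1)-th prefill from the front must occupy a decode slot; swap them.
        swap(prefills[i - 1], decode_idx)
        modified = True
    return is_decode, modified


# Example: [P, D, P, D] -> [D, D, P, P]; here a single swap suffices.
print(reorder_decodes_first([False, True, False, True]))
```

Using `prefills[i - 1]` directly removes the need for the old `first_prefill` counter: if the i-th decode from the back is out of place, at least i prefills must lie in the decode region, and since `prefills` is sorted ascending, `prefills[i - 1]` is one of them.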