From f6bb18fd9a19e5e4fb1991339638fc666d06b27a Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 5 Mar 2025 20:10:13 -0500 Subject: [PATCH] [BugFix] MLA + V1, illegal memory access and accuracy issues (#14253) Signed-off-by: Lucas Wilkinson --- tests/v1/worker/test_gpu_input_batch.py | 90 +++++- vllm/v1/attention/backends/flash_attn.py | 4 +- vllm/v1/attention/backends/mla/common.py | 289 +++++++++++-------- vllm/v1/attention/backends/mla/flashmla.py | 58 ++-- vllm/v1/attention/backends/mla/triton_mla.py | 7 +- vllm/v1/worker/gpu_input_batch.py | 25 +- vllm/v1/worker/gpu_model_runner.py | 6 +- 7 files changed, 326 insertions(+), 153 deletions(-) diff --git a/tests/v1/worker/test_gpu_input_batch.py b/tests/v1/worker/test_gpu_input_batch.py index 72ec7370..5f0cb1d3 100644 --- a/tests/v1/worker/test_gpu_input_batch.py +++ b/tests/v1/worker/test_gpu_input_batch.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import inspect from typing import Optional import numpy as np @@ -9,7 +10,8 @@ import torch from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState, + InputBatch) VOCAB_SIZE = 1024 NUM_OUTPUT_TOKENS = 20 @@ -20,6 +22,34 @@ CUDA_DEVICES = [ MAX_NUM_PROMPT_TOKENS = 64 +def _compare_objs(obj1, obj2): + attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a))) + attr_names = set([ + a[0] for a in attrs + if not (a[0].startswith('__') and a[0].endswith('__')) + ]) + for attr_name in attr_names: + a = getattr(obj1, attr_name) + b = getattr(obj2, attr_name) + + is_same = False + if isinstance(a, torch.Tensor): + if (a.numel() == 0 or b.numel() == 0): + is_same = (a.numel() == 0 and b.numel() == 0) + elif torch.allclose(a, b): + is_same = True + elif isinstance(a, np.ndarray): + if np.allclose(a, b): + is_same = True + elif isinstance(a, (BlockTable, SamplingMetadata)): + _compare_objs(a, b) + is_same = True # if we make it here must be same + elif a == b: + is_same = True + assert is_same, f"Attribute {attr_name} is different"\ + f" in {obj1} and {obj2}: {a} != {b}" + + def _remove_requests( input_batch: InputBatch, batch_size: int, reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]: @@ -254,3 +284,61 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int): assert torch.allclose( expected_sampling_metadata.allowed_token_ids_mask, sampling_metadata.allowed_token_ids_mask) + + +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("swap_list", [((0, 1), )]) +def test_swap_states_in_input_batch(device: str, batch_size: int, + swap_list: list): + """ + Tests the logic for managing sampling metadata in the InputBatch. + + This test involves adding a set of requests to the InputBatch, + followed by removing a subset of them. Afterward, the batch is compacted, + and the `make_sampling_metadata` method is invoked on the batch. The + output of `make_sampling_metadata` is then compared against the expected + results to ensure correctness. + """ + input_batch: InputBatch = InputBatch( + max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024, + ) + ref_input_batch: InputBatch = InputBatch( + max_num_reqs=batch_size, + max_model_len=1024, + max_num_blocks_per_req=10, + device=torch.device(device), + pin_memory=is_pin_memory_available(), + vocab_size=1024, + ) + + reqs: list[CachedRequestState] = [] + req_id_reqs = {} + req_id_output_token_ids = {} + # Add requests + for req_index in range(batch_size): + req: CachedRequestState = _construct_cached_request_state(req_index) + input_batch.add_request(req, req_index) + reqs.append(req) + req_id_reqs[req.req_id] = req + req_id_output_token_ids[req.req_id] = req.output_token_ids + + reordered_reqs = reqs.copy() + for swap_pair in swap_list: + reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \ + reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]] + input_batch.swap_states(swap_pair[0], swap_pair[1]) + + for req_index in range(batch_size): + req = reordered_reqs[req_index] + ref_input_batch.add_request(req, req_index) + + input_batch.refresh_sampling_metadata() + ref_input_batch.refresh_sampling_metadata() + + _compare_objs(input_batch, ref_input_batch) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 8bf7f358..e7c2fd41 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -100,8 +100,8 @@ class FlashAttentionMetadataBuilder: self.runner = runner def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput"): - pass + scheduler_output: "SchedulerOutput") -> bool: + return False def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int): diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 824ffcfd..c98262ee 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -275,17 +275,47 @@ class MLACommonBackend(AttentionBackend): @dataclass -class MLACommonMetadata: +class MLACommonPrefillMetadata: + """ Prefill Specific Metadata """ + + @dataclass + class ChunkedContextMetadata: + # New for MLA (compared to FlashAttention) + # For handling chunked prefill + cu_seq_lens: torch.Tensor + starts: torch.Tensor + seq_tot: list[int] + max_seq_lens: list[int] + workspace: torch.Tensor + + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + chunked_context: Optional[ChunkedContextMetadata] = None + + +@dataclass +class MLACommonDecodeMetadata: + # Input positions for rotrary embeddings since for MLA the rotary + # position embeddings are applied inside the attention backend + input_positions: torch.Tensor + block_table: torch.Tensor + seq_lens: torch.Tensor + + +D = TypeVar("D", bound=MLACommonDecodeMetadata) + + +@dataclass +class MLACommonMetadata(Generic[D]): """Metadata for MLACommon. NOTE: Please read the comment at the top of the file before trying to understand this class """ - # New for MLA (compared to FlashAttention) - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| @@ -295,30 +325,23 @@ class MLACommonMetadata: # |-- query_len ---| num_actual_tokens: int # Number of tokens excluding padding. - max_query_len: int query_start_loc: torch.Tensor - max_seq_len: int - seq_lens: torch.Tensor - block_table: torch.Tensor slot_mapping: torch.Tensor + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + # For logging. num_input_tokens: int = 0 # Number of tokens including padding. # The dimension of the attention heads head_dim: Optional[int] = None - # New for MLA (compared to FlashAttention) - # For chunked prefill - num_decodes: Optional[int] = None - num_decode_tokens: Optional[int] = None - num_prefills: Optional[int] = None - has_context: bool = False - context_chunk_cu_seq_lens: Optional[torch.Tensor] = None - context_chunk_starts: Optional[torch.Tensor] = None - context_chunk_seq_tot: Optional[list[int]] = None - context_chunk_max_seq_lens: Optional[list[int]] = None - chunked_prefill_workspace: Optional[torch.Tensor] = None + decode: Optional[D] = None + prefill: Optional[MLACommonPrefillMetadata] = None def __post_init__(self): supported_head_sizes = MLACommonBackend.get_supported_head_sizes() @@ -329,10 +352,10 @@ class MLACommonMetadata: f"received {self.head_dim}.") -T = TypeVar("T", bound=MLACommonMetadata) +M = TypeVar("M", bound=MLACommonMetadata) -class MLACommonMetadataBuilder(Generic[T]): +class MLACommonMetadataBuilder(Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -340,8 +363,9 @@ class MLACommonMetadataBuilder(Generic[T]): def __init__(self, runner: "GPUModelRunner", - cls: Optional[type[T]] = None): - self.cls = cls if cls is not None else MLACommonMetadata + metadata_cls: Optional[type[M]] = None): + self.metadata_cls = metadata_cls \ + if metadata_cls is not None else MLACommonMetadata self.runner = runner scheduler_config = runner.scheduler_config model_config = runner.model_config @@ -375,7 +399,7 @@ class MLACommonMetadataBuilder(Generic[T]): self.page_size = self.runner.block_size def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput"): + scheduler_output: "SchedulerOutput") -> bool: # We now want to reorder the batch so that the "decode" requests are and # the front and the "prefill" requests are at the using the least amount # swaps possible. (NOTE for now we loosely use "decode" to mean requests @@ -413,6 +437,7 @@ class MLACommonMetadataBuilder(Generic[T]): num_decodes = len(decodes) num_prefills = len(prefills) first_prefill = 0 + modified_batch = False for i in range(1, min(num_decodes, num_prefills) + 1): # If the decode is at the "back" of the batch, i, we can swap it @@ -421,6 +446,7 @@ class MLACommonMetadataBuilder(Generic[T]): input_batch.swap_states(prefills[first_prefill], decodes[num_decodes - i]) first_prefill += 1 + modified_batch = True else: break @@ -432,10 +458,21 @@ class MLACommonMetadataBuilder(Generic[T]): self._num_decode_tokens = num_decode_tokens self._num_prefill_tokens = num_prefill_tokens + return modified_batch + + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, seq_lens: torch.Tensor): + return MLACommonDecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + ) + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int) -> T: + common_prefix_len: int) -> M: + assert self._num_decodes + self._num_prefills == num_reqs + device = self.runner.device - max_seq_len = self.runner.seq_lens_np[:num_reqs].max() query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( device, non_blocking=True) seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device, @@ -447,85 +484,103 @@ class MLACommonMetadataBuilder(Generic[T]): input_positions = self.runner.positions_cpu[:num_actual_tokens].to( device, non_blocking=True).long() - context_chunk_cu_seq_lens = None - context_chunk_starts = None - context_chunk_seq_tot = None - context_chunk_max_seq_lens = None + prefill_metadata = None + if self._num_prefills > 0: + reqs_start = self._num_decodes # prefill_start + tokens_start = self._num_decode_tokens - num_computed_tokens_cpu_tensor = \ - self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs] - context_lens_tensor = \ - num_computed_tokens_cpu_tensor.to(device, non_blocking=True) + context_lens_cpu = self.runner.input_batch.\ + num_computed_tokens_cpu_tensor[reqs_start:num_reqs] + context_lens = context_lens_cpu.to(device, non_blocking=True) - if self.chunked_prefill_enabled and self._num_prefills > 0 \ - and context_lens_tensor[self._num_decodes:].max() > 0: - # NOTE: it is recommend you read the `Chunked Prefill` section in - # the comment at the top of the file before trying to understand - # the following code + chunked_context_metadata = None + if self.chunked_prefill_enabled and self._num_prefills > 0 \ + and context_lens.max() > 0: + # NOTE: it is recommend you read the `Chunked Prefill` section + # in the comment at the top of the file before trying to + # understand the following code - self.has_context = True + num_prefills_with_context = (context_lens > 0).sum().item() - num_prefills_with_context = \ - (context_lens_tensor[self._num_decodes:] > 0).sum().item() + # currently we allocate an equal amount of workspace for each + # prefill in the batch, we could probably use a more advanced + # algorithm here and allocate more workspace to prefills with + # longer context lengths + max_context_chunk = \ + self.chunked_prefill_workspace_size \ + // num_prefills_with_context - # currently we allocate an equal amount of workspace for each - # prefill in the batch, we could probably use a more advanced - # algorithm here and allocate more workspace to prefills with - # longer context lengths - max_context_chunk = \ - self.chunked_prefill_workspace_size // num_prefills_with_context + # align max_context_chunk to page_size by rounding down, + # currently the `gather_cache` kernel cannot handle + # `context_chunk_starts` that are not aligned to page_size + max_context_chunk = round_down(max_context_chunk, + self.page_size) - # align max_context_chunk to page_size by rounding down, - # currently the `gather_cache` kernel cannot handle - # `context_chunk_starts` that are not aligned to page_size - max_context_chunk = round_down(max_context_chunk, self.page_size) - assert max_context_chunk > 0 - num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk) + assert max_context_chunk > 0 + num_chunks = cdiv(context_lens.max(), max_context_chunk) - # if `max_context_chunk = 256`, `num_chunks = 3`, and - # `num_prefills_with_context = 4`, create a tensor that looks like - # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] - context_chunk_starts = \ - torch.arange(num_chunks, device=device, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, self._num_prefills) \ - * max_context_chunk - chunk_ends = torch.min(context_lens_tensor[self._num_decodes:] \ - .unsqueeze(0), context_chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0) - _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( - torch.int32) - zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \ - .unsqueeze(-1) - context_chunk_cu_seq_lens = \ - torch.cat([zero, _context_chunk_cu_seq_lens], dim=1) - context_chunk_max_seq_lens = \ - chunk_seq_lens.max(dim=1).values.tolist() - context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist() - assert max(context_chunk_seq_tot) <= \ - self.chunked_prefill_workspace_size + # if `max_context_chunk = 256`, `num_chunks = 3`, and + # `num_prefills_with_context = 4`, create a tensor that looks + # like + # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]] + chunk_starts = \ + torch.arange(num_chunks, device=device, dtype=torch.int32) \ + .unsqueeze(1).expand(-1, self._num_prefills) \ + * max_context_chunk + chunk_ends = torch.min(context_lens.unsqueeze(0), + chunk_starts + max_context_chunk) + chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) + _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to( + torch.int32) + zero = torch.zeros(num_chunks, + dtype=torch.int32, + device=device).unsqueeze(-1) - return self.cls( - input_positions=input_positions, + chunked_context_metadata = \ + MLACommonPrefillMetadata.ChunkedContextMetadata( + cu_seq_lens=torch.cat( + [zero, _chunk_cu_seq_lens], dim=1), + starts=chunk_starts, + seq_tot=chunk_seq_lens.sum(dim=1).tolist(), + max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), + workspace=self.chunked_prefill_workspace, + ) + + assert max(chunked_context_metadata.max_seq_lens) <= \ + self.chunked_prefill_workspace_size + + prefill_metadata = MLACommonPrefillMetadata( + input_positions=input_positions[tokens_start:], + block_table=block_table[reqs_start:, ...], + query_start_loc=query_start_loc[reqs_start:] - + query_start_loc[reqs_start], + max_query_len=seq_lens[reqs_start:].max().item(), + chunked_context=chunked_context_metadata, + ) + + decode_metadata = None + if self._num_decodes > 0: + decode_metadata = self._build_decode( + input_positions=input_positions[:self._num_decode_tokens], + block_table=block_table[:self._num_decodes, ...], + seq_lens=seq_lens[:self._num_decodes], + ) + + return self.metadata_cls( num_actual_tokens=num_actual_tokens, - max_query_len=max_query_len, query_start_loc=query_start_loc, - max_seq_len=max_seq_len, - seq_lens=seq_lens, - block_table=block_table, slot_mapping=slot_mapping, head_dim=self.runner.model_config.get_head_size(), # MLACommonMetadata Chunk prefill specific num_decodes=self._num_decodes, num_decode_tokens=self._num_decode_tokens, num_prefills=self._num_prefills, - context_chunk_cu_seq_lens=context_chunk_cu_seq_lens, - context_chunk_starts=context_chunk_starts, - context_chunk_seq_tot=context_chunk_seq_tot, - context_chunk_max_seq_lens=context_chunk_max_seq_lens, + prefill=prefill_metadata, + decode=decode_metadata, ) -class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): +class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -798,28 +853,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, ): - assert attn_metadata.num_prefills is not None - assert attn_metadata.context_chunk_seq_tot is not None - assert attn_metadata.context_chunk_cu_seq_lens is not None - assert attn_metadata.context_chunk_starts is not None - assert attn_metadata.context_chunk_max_seq_lens is not None + assert attn_metadata.prefill is not None + prefill_metadata = attn_metadata.prefill + assert prefill_metadata.chunked_context is not None output = None - iters = len(attn_metadata.context_chunk_seq_tot) - - assert attn_metadata.chunked_prefill_workspace is not None - workspace = attn_metadata.chunked_prefill_workspace + iters = len(prefill_metadata.chunked_context.seq_tot) + workspace = prefill_metadata.chunked_context.workspace for i in range(iters): - toks = attn_metadata.context_chunk_seq_tot[i] + toks = prefill_metadata.chunked_context.seq_tot[i] ops.gather_cache( src_cache=kv_c_and_k_pe_cache, dst=workspace, - block_table=attn_metadata.block_table, - cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i], + block_table=prefill_metadata.block_table, + cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], batch_size=attn_metadata.num_prefills, - seq_starts=attn_metadata.context_chunk_starts[i], + seq_starts=prefill_metadata.chunked_context.starts[i], ) kv_c_normed = workspace[:toks]\ @@ -845,10 +896,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): q=q, k=k, v=v_padded, - cu_seqlens_q=attn_metadata.query_start_loc, - cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i], - max_seqlen_q=attn_metadata.max_query_len, - max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i], + cu_seqlens_q=prefill_metadata.query_start_loc, + cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i], + max_seqlen_q=prefill_metadata.max_query_len, + max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i], softmax_scale=self.scale, causal=False, # Context is unmasked return_softmax_lse=True, @@ -881,7 +932,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: MLACommonMetadata, ) -> torch.Tensor: - has_context = attn_metadata.has_context + assert attn_metadata.prefill is not None + + has_context = attn_metadata.prefill.chunked_context is not None kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ @@ -898,10 +951,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): q=q, k=k, v=v_padded, - cu_seqlens_q=attn_metadata.query_start_loc, - cu_seqlens_k=attn_metadata.query_start_loc, - max_seqlen_q=attn_metadata.max_query_len, - max_seqlen_k=attn_metadata.max_seq_len, + cu_seqlens_q=attn_metadata.prefill.query_start_loc, + cu_seqlens_k=attn_metadata.prefill.query_start_loc, + max_seqlen_q=attn_metadata.prefill.max_query_len, + max_seqlen_k=attn_metadata.prefill.max_query_len, softmax_scale=self.scale, causal=True, return_softmax_lse=has_context, @@ -934,7 +987,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: T, + attn_metadata: M, ) -> torch.Tensor: raise NotImplementedError @@ -945,7 +998,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, - attn_metadata: T, + attn_metadata: M, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -966,7 +1019,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) - assert hasattr(attn_metadata, "input_positions") assert attn_metadata.num_decodes is not None and \ attn_metadata.num_prefills is not None and \ @@ -978,28 +1030,27 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens] decode_k_pe = k_pe[:num_decode_tokens] - decode_input_positions = \ - attn_metadata.input_positions[:num_decode_tokens] prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:] - prefill_input_positions = \ - attn_metadata.input_positions[num_decode_tokens:] prefill_k_c_normed = k_c_normed[num_decode_tokens:] if has_decode: + assert attn_metadata.decode is not None decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( - decode_input_positions, decode_q_pe, decode_k_pe) + attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe) if has_prefill: + assert attn_metadata.prefill is not None prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:] prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb( - prefill_input_positions, prefill_q_pe, prefill_k_pe) + attn_metadata.prefill.input_positions, prefill_q_pe, + prefill_k_pe) # write the latent and rope to kv cache if kv_cache.numel() > 0: diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index b357d714..d5bf9cd2 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, is_flashmla_supported) from vllm.logger import init_logger from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, MLACommonImpl, MLACommonMetadata, MLACommonMetadataBuilder) @@ -38,34 +39,41 @@ class FlashMLABackend(MLACommonBackend): @dataclass -class FlashMLAMetadata(MLACommonMetadata): - decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor, - torch.Tensor]] = None - decode_num_splits: Optional[torch.Tensor] = None +class FlashMLADecodeMetadata(MLACommonDecodeMetadata): + tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor] + num_splits: torch.Tensor + + +@dataclass +class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]): + pass class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]): def __init__(self, runner): - super().__init__(runner, cls=FlashMLAMetadata) + super().__init__(runner) self.num_q_heads = self.runner.model_config.get_num_attention_heads( self.runner.parallel_config) - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int): - m = super().build(num_reqs, num_actual_tokens, max_query_len, - common_prefix_len) + def _build_decode(self, input_positions: torch.Tensor, + block_table: torch.Tensor, + seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + tile_scheduler_metadata, num_splits = \ + get_mla_metadata( + seq_lens, + self.num_q_heads, + 1, # MQA for the decode path + ) - if m.num_decode_tokens is not None and m.num_decode_tokens > 0: - m.decode_tile_scheduler_metadata, m.decode_num_splits = \ - get_mla_metadata( - m.seq_lens[:m.num_decode_tokens], - self.num_q_heads, - 1, # MQA for the decode path - ) - - return m + return FlashMLADecodeMetadata( + input_positions=input_positions, + block_table=block_table, + seq_lens=seq_lens, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + ) class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): @@ -115,6 +123,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): attn_metadata: FlashMLAMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 FlashMLA not yet supported") @@ -124,14 +134,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): o, _ = flash_mla_with_kvcache( q=q, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 - block_table=attn_metadata.block_table[:attn_metadata.num_decodes, - ...], - cache_seqlens=attn_metadata.seq_lens[:attn_metadata. - num_decode_tokens], + block_table=attn_metadata.decode.block_table, + cache_seqlens=attn_metadata.decode.seq_lens, head_dim_v=self.kv_lora_rank, - tile_scheduler_metadata=attn_metadata. - decode_tile_scheduler_metadata, - num_splits=attn_metadata.decode_num_splits, + tile_scheduler_metadata=attn_metadata.decode. + tile_scheduler_metadata, + num_splits=attn_metadata.decode.num_splits, softmax_scale=self.scale, causal=True, ) diff --git a/vllm/v1/attention/backends/mla/triton_mla.py b/vllm/v1/attention/backends/mla/triton_mla.py index 3f9b349a..cef7a3a9 100644 --- a/vllm/v1/attention/backends/mla/triton_mla.py +++ b/vllm/v1/attention/backends/mla/triton_mla.py @@ -69,6 +69,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): attn_metadata: MLACommonMetadata, ) -> torch.Tensor: assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + if self.kv_cache_dtype.startswith("fp8"): raise NotImplementedError("FP8 Triton MLA not yet supported") @@ -104,7 +106,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): # Run MQA decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o, - attn_metadata.block_table, attn_metadata.seq_lens, - attn_logits, num_kv_splits, self.scale, PAGE_SIZE) + attn_metadata.decode.block_table, + attn_metadata.decode.seq_lens, attn_logits, + num_kv_splits, self.scale, PAGE_SIZE) return self._v_up_proj_and_o_proj(o) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 02b2aa3e..6239a182 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -383,8 +383,6 @@ class InputBatch: self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1] self.num_tokens[i1], self.num_tokens[i2] =\ self.num_tokens[i2], self.num_tokens[i1] - self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ - self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\ self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1] self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\ @@ -406,24 +404,47 @@ class InputBatch: self.min_p_cpu[i1], self.min_p_cpu[i2] =\ self.min_p_cpu[i2], self.min_p_cpu[i1] + # NOTE: the following is unsafe + # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\ + # self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...] + # instead, we need to temporiarily copy the data for one of the indices + # TODO(lucas): optimize this by only copying valid indices + tmp = self.token_ids_cpu[i1, ...].copy() + self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...] + self.token_ids_cpu[i2, ...] = tmp + g1 = self.generators.get(i1) g2 = self.generators.get(i2) if g1 is not None: self.generators[i2] = g1 + else: + self.generators.pop(i2, None) if g2 is not None: self.generators[i1] = g2 + else: + self.generators.pop(i1, None) t1 = self.min_tokens.get(i1) t2 = self.min_tokens.get(i2) if t1 is not None: self.min_tokens[i2] = t1 + else: + self.min_tokens.pop(i2, None) if t2 is not None: self.min_tokens[i1] = t2 + else: + self.min_tokens.pop(i1, None) self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\ self.request_lora_mapping[i2], self.request_lora_mapping[i1] self.logit_bias[i1], self.logit_bias[i2] =\ self.logit_bias[i2], self.logit_bias[i1] + + if self.allowed_token_ids_mask_cpu_tensor is not None: + self.allowed_token_ids_mask_cpu_tensor[i1], \ + self.allowed_token_ids_mask_cpu_tensor[i2] =\ + self.allowed_token_ids_mask_cpu_tensor[i2], \ + self.allowed_token_ids_mask_cpu_tensor[i1] self.block_table.swap_row(i1, i2) def condense(self, empty_req_indices: list[int]) -> None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a1a50e89..519f38cb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -456,8 +456,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): # Some attention backends (namely MLA) may want to separate requests # based on if the attention computation will be compute-bound or # memory-bound. This gives them a hook to do that. - self.attn_metadata_builder.reorder_batch(self.input_batch, - scheduler_output) + modified_batch = self.attn_metadata_builder.reorder_batch( + self.input_batch, scheduler_output) + if modified_batch: + self.input_batch.refresh_sampling_metadata() # OPTIMIZATION: Start copying the block table first. # This way, we can overlap the copy with the following CPU operations.