[BugFix] MLA + V1, illegal memory access and accuracy issues (#14253)

Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Lucas Wilkinson 2025-03-05 20:10:13 -05:00 committed by GitHub
parent 71eaf8969b
commit f6bb18fd9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 326 additions and 153 deletions

View File

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
+import inspect
 from typing import Optional
 import numpy as np
@@ -9,7 +10,8 @@ import torch
 from vllm.sampling_params import SamplingParams
 from vllm.utils import is_pin_memory_available, make_tensor_with_pad
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
+from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
+                                            InputBatch)
 VOCAB_SIZE = 1024
 NUM_OUTPUT_TOKENS = 20
@@ -20,6 +22,34 @@ CUDA_DEVICES = [
 MAX_NUM_PROMPT_TOKENS = 64
+def _compare_objs(obj1, obj2):
+    attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
+    attr_names = set([
+        a[0] for a in attrs
+        if not (a[0].startswith('__') and a[0].endswith('__'))
+    ])
+    for attr_name in attr_names:
+        a = getattr(obj1, attr_name)
+        b = getattr(obj2, attr_name)
+
+        is_same = False
+        if isinstance(a, torch.Tensor):
+            if (a.numel() == 0 or b.numel() == 0):
+                is_same = (a.numel() == 0 and b.numel() == 0)
+            elif torch.allclose(a, b):
+                is_same = True
+        elif isinstance(a, np.ndarray):
+            if np.allclose(a, b):
+                is_same = True
+        elif isinstance(a, (BlockTable, SamplingMetadata)):
+            _compare_objs(a, b)
+            is_same = True  # if we make it here must be same
+        elif a == b:
+            is_same = True
+        assert is_same, f"Attribute {attr_name} is different" \
+            f" in {obj1} and {obj2}: {a} != {b}"
 def _remove_requests(
         input_batch: InputBatch, batch_size: int,
         reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
swap_list: list):
"""
Tests the logic for managing sampling metadata in the InputBatch.
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_blocks_per_req=10,
device=torch.device(device),
pin_memory=is_pin_memory_available(),
vocab_size=1024,
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
input_batch.add_request(req, req_index)
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
reordered_reqs = reqs.copy()
for swap_pair in swap_list:
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
input_batch.swap_states(swap_pair[0], swap_pair[1])
for req_index in range(batch_size):
req = reordered_reqs[req_index]
ref_input_batch.add_request(req, req_index)
input_batch.refresh_sampling_metadata()
ref_input_batch.refresh_sampling_metadata()
_compare_objs(input_batch, ref_input_batch)
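The helper `_compare_objs` drives the whole assertion: it walks every non-dunder attribute, uses `allclose` for tensors/arrays, and recurses into nested objects. A minimal, self-contained sketch of the same pattern, using hypothetical `Inner`/`Outer` classes and a trimmed-down `compare()` helper rather than the vLLM types:

```python
# Sketch of the attribute-walking comparison used by _compare_objs.
# Inner/Outer and compare() are illustrative stand-ins, not vLLM code.
import inspect
from dataclasses import dataclass

import numpy as np


@dataclass
class Inner:
    values: np.ndarray


@dataclass
class Outer:
    name: str
    inner: Inner


def compare(obj1, obj2):
    attrs = inspect.getmembers(obj1, lambda a: not inspect.isroutine(a))
    for name, _ in attrs:
        if name.startswith('__') and name.endswith('__'):
            continue  # skip dunder attributes
        a, b = getattr(obj1, name), getattr(obj2, name)
        if isinstance(a, np.ndarray):
            assert np.allclose(a, b), f"{name} differs"
        elif isinstance(a, Inner):
            compare(a, b)  # recurse into nested objects
        else:
            assert a == b, f"{name} differs"


compare(Outer("a", Inner(np.arange(4))), Outer("a", Inner(np.arange(4))))
print("objects match")
```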

View File

@@ -100,8 +100,8 @@ class FlashAttentionMetadataBuilder:
         self.runner = runner
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput"):
-        pass
+                      scheduler_output: "SchedulerOutput") -> bool:
+        return False
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int):

View File

@@ -275,17 +275,47 @@ class MLACommonBackend(AttentionBackend):
 @dataclass
-class MLACommonMetadata:
+class MLACommonPrefillMetadata:
+    """ Prefill Specific Metadata """
+
+    @dataclass
+    class ChunkedContextMetadata:
+        # New for MLA (compared to FlashAttention)
+        # For handling chunked prefill
+        cu_seq_lens: torch.Tensor
+        starts: torch.Tensor
+        seq_tot: list[int]
+        max_seq_lens: list[int]
+        workspace: torch.Tensor
+
+    # Input positions for rotary embeddings since for MLA the rotary
+    # position embeddings are applied inside the attention backend
+    input_positions: torch.Tensor
+    block_table: torch.Tensor
+    query_start_loc: torch.Tensor
+    max_query_len: int
+    chunked_context: Optional[ChunkedContextMetadata] = None
+
+
+@dataclass
+class MLACommonDecodeMetadata:
+    # Input positions for rotary embeddings since for MLA the rotary
+    # position embeddings are applied inside the attention backend
+    input_positions: torch.Tensor
+    block_table: torch.Tensor
+    seq_lens: torch.Tensor
+
+
+D = TypeVar("D", bound=MLACommonDecodeMetadata)
+
+
+@dataclass
+class MLACommonMetadata(Generic[D]):
     """Metadata for MLACommon.
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
-    # New for MLA (compared to FlashAttention)
-    # Input positions for rotary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
     # NOTE(sang): Definition of context_len, query_len, and seq_len.
     # |---------- N-1 iteration --------|
     # |---------------- N iteration ---------------------|
@@ -295,30 +325,23 @@ class MLACommonMetadata:
     # |-- query_len ---|
     num_actual_tokens: int  # Number of tokens excluding padding.
-    max_query_len: int
     query_start_loc: torch.Tensor
-    max_seq_len: int
-    seq_lens: torch.Tensor
-    block_table: torch.Tensor
     slot_mapping: torch.Tensor
+    # New for MLA (compared to FlashAttention)
+    # For handling prefill decode split
+    num_decodes: int
+    num_decode_tokens: int
+    num_prefills: int
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
     # The dimension of the attention heads
     head_dim: Optional[int] = None
-    # New for MLA (compared to FlashAttention)
-    # For chunked prefill
-    num_decodes: Optional[int] = None
-    num_decode_tokens: Optional[int] = None
-    num_prefills: Optional[int] = None
-    has_context: bool = False
-    context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
-    context_chunk_starts: Optional[torch.Tensor] = None
-    context_chunk_seq_tot: Optional[list[int]] = None
-    context_chunk_max_seq_lens: Optional[list[int]] = None
-    chunked_prefill_workspace: Optional[torch.Tensor] = None
+    decode: Optional[D] = None
+    prefill: Optional[MLACommonPrefillMetadata] = None
     def __post_init__(self):
         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
@@ -329,10 +352,10 @@ class MLACommonMetadata:
                              f"received {self.head_dim}.")
-T = TypeVar("T", bound=MLACommonMetadata)
+M = TypeVar("M", bound=MLACommonMetadata)
-class MLACommonMetadataBuilder(Generic[T]):
+class MLACommonMetadataBuilder(Generic[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -340,8 +363,9 @@ class MLACommonMetadataBuilder(Generic[T]):
     def __init__(self,
                  runner: "GPUModelRunner",
-                 cls: Optional[type[T]] = None):
-        self.cls = cls if cls is not None else MLACommonMetadata
+                 metadata_cls: Optional[type[M]] = None):
+        self.metadata_cls = metadata_cls \
+            if metadata_cls is not None else MLACommonMetadata
         self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config
@@ -375,7 +399,7 @@ class MLACommonMetadataBuilder(Generic[T]):
         self.page_size = self.runner.block_size
     def reorder_batch(self, input_batch: "InputBatch",
-                      scheduler_output: "SchedulerOutput"):
+                      scheduler_output: "SchedulerOutput") -> bool:
         # We now want to reorder the batch so that the "decode" requests are
         # at the front and the "prefill" requests are at the back, using the
         # least amount of swaps possible. (NOTE for now we loosely use
         # "decode" to mean requests
@@ -413,6 +437,7 @@ class MLACommonMetadataBuilder(Generic[T]):
         num_decodes = len(decodes)
         num_prefills = len(prefills)
         first_prefill = 0
+        modified_batch = False
         for i in range(1, min(num_decodes, num_prefills) + 1):
             # If the decode is at the "back" of the batch, i, we can swap it
@@ -421,6 +446,7 @@ class MLACommonMetadataBuilder(Generic[T]):
                 input_batch.swap_states(prefills[first_prefill],
                                         decodes[num_decodes - i])
                 first_prefill += 1
+                modified_batch = True
             else:
                 break
@@ -432,10 +458,21 @@ class MLACommonMetadataBuilder(Generic[T]):
         self._num_decode_tokens = num_decode_tokens
         self._num_prefill_tokens = num_prefill_tokens
+        return modified_batch
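The swap loop above pulls "decode" requests from the back of the batch into the earliest prefill slots until the two groups are contiguous, and now also reports whether anything moved. A minimal, self-contained sketch of that idea on plain Python lists; the helper name and swap condition are illustrative assumptions, not the vLLM implementation:

```python
# Toy version of the decode-to-front reordering: is_decode marks requests
# whose next step is likely memory-bound. Not vLLM code.
def reorder_decodes_first(is_decode: list[bool]) -> tuple[list[int], bool]:
    order = list(range(len(is_decode)))
    decodes = [i for i in order if is_decode[i]]
    prefills = [i for i in order if not is_decode[i]]

    modified = False
    first_prefill = 0
    # Walk decodes from the back; swap each one that sits behind a prefill.
    for i in range(1, min(len(decodes), len(prefills)) + 1):
        if decodes[-i] > prefills[first_prefill]:
            a, b = prefills[first_prefill], decodes[-i]
            order[a], order[b] = order[b], order[a]
            first_prefill += 1
            modified = True
        else:
            break
    return order, modified


# Batch [prefill, decode, prefill, decode]: one swap is enough.
print(reorder_decodes_first([False, True, False, True]))
# -> ([3, 1, 2, 0], True)  i.e. the decodes (requests 3 and 1) are now
#    in the first two slots, prefills follow.
```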
+    def _build_decode(self, input_positions: torch.Tensor,
+                      block_table: torch.Tensor, seq_lens: torch.Tensor):
+        return MLACommonDecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+        )
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int) -> T:
+              common_prefix_len: int) -> M:
+        assert self._num_decodes + self._num_prefills == num_reqs
         device = self.runner.device
-        max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
             device, non_blocking=True)
         seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
@@ -447,85 +484,103 @@ class MLACommonMetadataBuilder(Generic[T]):
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
-        context_chunk_cu_seq_lens = None
-        context_chunk_starts = None
-        context_chunk_seq_tot = None
-        context_chunk_max_seq_lens = None
-        num_computed_tokens_cpu_tensor = \
-            self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
-        context_lens_tensor = \
-            num_computed_tokens_cpu_tensor.to(device, non_blocking=True)
-        if self.chunked_prefill_enabled and self._num_prefills > 0 \
-                and context_lens_tensor[self._num_decodes:].max() > 0:
-            # NOTE: it is recommended you read the `Chunked Prefill` section in
-            # the comment at the top of the file before trying to understand
-            # the following code
-            self.has_context = True
-            num_prefills_with_context = \
-                (context_lens_tensor[self._num_decodes:] > 0).sum().item()
-            # currently we allocate an equal amount of workspace for each
-            # prefill in the batch, we could probably use a more advanced
-            # algorithm here and allocate more workspace to prefills with
-            # longer context lengths
-            max_context_chunk = \
-                self.chunked_prefill_workspace_size // num_prefills_with_context
-            # align max_context_chunk to page_size by rounding down,
-            # currently the `gather_cache` kernel cannot handle
-            # `context_chunk_starts` that are not aligned to page_size
-            max_context_chunk = round_down(max_context_chunk, self.page_size)
-            assert max_context_chunk > 0
-            num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
-            # if `max_context_chunk = 256`, `num_chunks = 3`, and
-            # `num_prefills_with_context = 4`, create a tensor that looks like
-            # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
-            context_chunk_starts = \
-                torch.arange(num_chunks, device=device, dtype=torch.int32) \
-                .unsqueeze(1).expand(-1, self._num_prefills) \
-                * max_context_chunk
-            chunk_ends = torch.min(context_lens_tensor[self._num_decodes:] \
-                .unsqueeze(0), context_chunk_starts + max_context_chunk)
-            chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
-            _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
-                torch.int32)
-            zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \
-                .unsqueeze(-1)
-            context_chunk_cu_seq_lens = \
-                torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
-            context_chunk_max_seq_lens = \
-                chunk_seq_lens.max(dim=1).values.tolist()
-            context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
-            assert max(context_chunk_seq_tot) <= \
-                self.chunked_prefill_workspace_size
-        return self.cls(
-            input_positions=input_positions,
-            num_actual_tokens=num_actual_tokens,
-            max_query_len=max_query_len,
-            query_start_loc=query_start_loc,
-            max_seq_len=max_seq_len,
-            seq_lens=seq_lens,
-            block_table=block_table,
-            slot_mapping=slot_mapping,
-            head_dim=self.runner.model_config.get_head_size(),
-            # MLACommonMetadata Chunk prefill specific
-            num_decodes=self._num_decodes,
-            num_decode_tokens=self._num_decode_tokens,
-            num_prefills=self._num_prefills,
-            context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
-            context_chunk_starts=context_chunk_starts,
-            context_chunk_seq_tot=context_chunk_seq_tot,
-            context_chunk_max_seq_lens=context_chunk_max_seq_lens,
-        )
+        prefill_metadata = None
+        if self._num_prefills > 0:
+            reqs_start = self._num_decodes  # prefill_start
+            tokens_start = self._num_decode_tokens
+            context_lens_cpu = self.runner.input_batch.\
+                num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
+            context_lens = context_lens_cpu.to(device, non_blocking=True)
+            chunked_context_metadata = None
+            if self.chunked_prefill_enabled and self._num_prefills > 0 \
+                    and context_lens.max() > 0:
+                # NOTE: it is recommended you read the `Chunked Prefill`
+                # section in the comment at the top of the file before
+                # trying to understand the following code
+                num_prefills_with_context = (context_lens > 0).sum().item()
+                # currently we allocate an equal amount of workspace for each
+                # prefill in the batch, we could probably use a more advanced
+                # algorithm here and allocate more workspace to prefills with
+                # longer context lengths
+                max_context_chunk = \
+                    self.chunked_prefill_workspace_size \
+                    // num_prefills_with_context
+                # align max_context_chunk to page_size by rounding down,
+                # currently the `gather_cache` kernel cannot handle
+                # `context_chunk_starts` that are not aligned to page_size
+                max_context_chunk = round_down(max_context_chunk,
                                                self.page_size)
+                assert max_context_chunk > 0
+                num_chunks = cdiv(context_lens.max(), max_context_chunk)
+                # if `max_context_chunk = 256`, `num_chunks = 3`, and
+                # `num_prefills_with_context = 4`, create a tensor that looks
+                # like
+                # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
+                chunk_starts = \
+                    torch.arange(num_chunks, device=device, dtype=torch.int32) \
+                    .unsqueeze(1).expand(-1, self._num_prefills) \
+                    * max_context_chunk
+                chunk_ends = torch.min(context_lens.unsqueeze(0),
+                                       chunk_starts + max_context_chunk)
+                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
+                _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
+                    torch.int32)
+                zero = torch.zeros(num_chunks,
+                                   dtype=torch.int32,
+                                   device=device).unsqueeze(-1)
+                chunked_context_metadata = \
+                    MLACommonPrefillMetadata.ChunkedContextMetadata(
+                        cu_seq_lens=torch.cat(
+                            [zero, _chunk_cu_seq_lens], dim=1),
+                        starts=chunk_starts,
+                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
+                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
+                        workspace=self.chunked_prefill_workspace,
+                    )
+                assert max(chunked_context_metadata.max_seq_lens) <= \
+                    self.chunked_prefill_workspace_size
+            prefill_metadata = MLACommonPrefillMetadata(
+                input_positions=input_positions[tokens_start:],
+                block_table=block_table[reqs_start:, ...],
+                query_start_loc=query_start_loc[reqs_start:] -
+                query_start_loc[reqs_start],
+                max_query_len=seq_lens[reqs_start:].max().item(),
+                chunked_context=chunked_context_metadata,
+            )
+        decode_metadata = None
+        if self._num_decodes > 0:
+            decode_metadata = self._build_decode(
+                input_positions=input_positions[:self._num_decode_tokens],
+                block_table=block_table[:self._num_decodes, ...],
+                seq_lens=seq_lens[:self._num_decodes],
+            )
+        return self.metadata_cls(
            num_actual_tokens=num_actual_tokens,
            query_start_loc=query_start_loc,
            slot_mapping=slot_mapping,
            head_dim=self.runner.model_config.get_head_size(),
            # MLACommonMetadata Chunk prefill specific
            num_decodes=self._num_decodes,
            num_decode_tokens=self._num_decode_tokens,
            num_prefills=self._num_prefills,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )
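The chunk bookkeeping above can be checked against the example values in the comment (`max_context_chunk = 256`, 4 prefills with context). A small standalone PyTorch sketch, run outside vLLM, with made-up context lengths:

```python
# Worked example of the chunked-context layout: context lengths are
# illustrative, the rest mirrors the arithmetic in build() above.
import torch

max_context_chunk = 256
context_lens = torch.tensor([700, 300, 256, 10], dtype=torch.int32)
num_prefills = context_lens.numel()
num_chunks = -(-int(context_lens.max()) // max_context_chunk)  # cdiv -> 3

chunk_starts = (torch.arange(num_chunks, dtype=torch.int32).unsqueeze(1)
                .expand(-1, num_prefills) * max_context_chunk)
chunk_ends = torch.min(context_lens.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

print(chunk_starts.tolist())
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
print(chunk_seq_lens.tolist())
# [[256, 256, 256, 10], [256, 44, 0, 0], [188, 0, 0, 0]]
print(chunk_seq_lens.sum(dim=1).tolist())          # seq_tot -> [778, 300, 188]
print(chunk_seq_lens.max(dim=1).values.tolist())   # max_seq_lens -> [256, 256, 188]
```

Each chunk's `seq_tot` is what gets gathered into the shared workspace, which is why the builder asserts it never exceeds `chunked_prefill_workspace_size`.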
-class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
+class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
@@ -798,28 +853,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
     ):
-        assert attn_metadata.num_prefills is not None
-        assert attn_metadata.context_chunk_seq_tot is not None
-        assert attn_metadata.context_chunk_cu_seq_lens is not None
-        assert attn_metadata.context_chunk_starts is not None
-        assert attn_metadata.context_chunk_max_seq_lens is not None
+        assert attn_metadata.prefill is not None
+        prefill_metadata = attn_metadata.prefill
+        assert prefill_metadata.chunked_context is not None
         output = None
-        iters = len(attn_metadata.context_chunk_seq_tot)
-        assert attn_metadata.chunked_prefill_workspace is not None
-        workspace = attn_metadata.chunked_prefill_workspace
+        iters = len(prefill_metadata.chunked_context.seq_tot)
+        workspace = prefill_metadata.chunked_context.workspace
         for i in range(iters):
-            toks = attn_metadata.context_chunk_seq_tot[i]
+            toks = prefill_metadata.chunked_context.seq_tot[i]
             ops.gather_cache(
                 src_cache=kv_c_and_k_pe_cache,
                 dst=workspace,
-                block_table=attn_metadata.block_table,
-                cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i],
+                block_table=prefill_metadata.block_table,
+                cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
                 batch_size=attn_metadata.num_prefills,
-                seq_starts=attn_metadata.context_chunk_starts[i],
+                seq_starts=prefill_metadata.chunked_context.starts[i],
             )
             kv_c_normed = workspace[:toks]\
@@ -845,10 +896,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
                 q=q,
                 k=k,
                 v=v_padded,
-                cu_seqlens_q=attn_metadata.query_start_loc,
-                cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i],
-                max_seqlen_q=attn_metadata.max_query_len,
-                max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i],
+                cu_seqlens_q=prefill_metadata.query_start_loc,
+                cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
+                max_seqlen_q=prefill_metadata.max_query_len,
+                max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
                 softmax_scale=self.scale,
                 causal=False,  # Context is unmasked
                 return_softmax_lse=True,
@@ -881,7 +932,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         kv_c_and_k_pe_cache: torch.Tensor,
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
-        has_context = attn_metadata.has_context
+        assert attn_metadata.prefill is not None
+        has_context = attn_metadata.prefill.chunked_context is not None
         kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
         k_nope, v = kv_nope\
@@ -898,10 +951,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             q=q,
             k=k,
             v=v_padded,
-            cu_seqlens_q=attn_metadata.query_start_loc,
-            cu_seqlens_k=attn_metadata.query_start_loc,
-            max_seqlen_q=attn_metadata.max_query_len,
-            max_seqlen_k=attn_metadata.max_seq_len,
+            cu_seqlens_q=attn_metadata.prefill.query_start_loc,
+            cu_seqlens_k=attn_metadata.prefill.query_start_loc,
+            max_seqlen_q=attn_metadata.prefill.max_query_len,
+            max_seqlen_k=attn_metadata.prefill.max_query_len,
             softmax_scale=self.scale,
             causal=True,
             return_softmax_lse=has_context,
@@ -934,7 +987,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         q_nope: torch.Tensor,
         q_pe: torch.Tensor,
         kv_c_and_k_pe_cache: torch.Tensor,
-        attn_metadata: T,
+        attn_metadata: M,
     ) -> torch.Tensor:
         raise NotImplementedError
@@ -945,7 +998,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         k_c_normed: torch.Tensor,  # key in unified attn
         k_pe: torch.Tensor,  # value in unified attn
         kv_cache: torch.Tensor,
-        attn_metadata: T,
+        attn_metadata: M,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
@@ -966,7 +1019,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
-        assert hasattr(attn_metadata, "input_positions")
         assert attn_metadata.num_decodes is not None and \
             attn_metadata.num_prefills is not None and \
@@ -978,28 +1030,27 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
         decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
         decode_k_pe = k_pe[:num_decode_tokens]
-        decode_input_positions = \
-            attn_metadata.input_positions[:num_decode_tokens]
         prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
         prefill_k_pe = k_pe[num_decode_tokens:]
-        prefill_input_positions = \
-            attn_metadata.input_positions[num_decode_tokens:]
         prefill_k_c_normed = k_c_normed[num_decode_tokens:]
         if has_decode:
+            assert attn_metadata.decode is not None
             decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                decode_input_positions, decode_q_pe, decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
         if has_prefill:
+            assert attn_metadata.prefill is not None
             prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                prefill_input_positions, prefill_q_pe, prefill_k_pe)
+                attn_metadata.prefill.input_positions, prefill_q_pe,
+                prefill_k_pe)
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
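After `reorder_batch`, all decode requests (and their tokens) sit at the front of the batch, which is what lets both `build` and `forward` carve up flat per-token and per-request tensors with simple prefix/suffix slices. A minimal sketch of that slicing convention with illustrative values, not vLLM code:

```python
# Decode-first layout assumed by the MLA metadata: per-token tensors are
# split with [:num_decode_tokens] / [num_decode_tokens:], per-request
# tensors with [:num_decodes] / [num_decodes:]. Values are made up.
import torch

num_decodes, num_prefills = 2, 2
num_decode_tokens = num_decodes          # decodes contribute one token each
prefill_query_lens = [3, 4]              # prefills contribute several tokens

positions = torch.arange(num_decode_tokens + sum(prefill_query_lens))
decode_positions = positions[:num_decode_tokens]    # -> decode metadata
prefill_positions = positions[num_decode_tokens:]   # -> prefill metadata

# query_start_loc covers the whole batch; the prefill-only version is
# re-based to start at zero, mirroring
#   query_start_loc[reqs_start:] - query_start_loc[reqs_start]
query_start_loc = torch.tensor([0, 1, 2, 5, 9])      # cumulative query lens
reqs_start = num_decodes
prefill_qsl = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
print(prefill_qsl)   # tensor([0, 3, 7])
```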

View File

@@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                          is_flashmla_supported)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
+                                                    MLACommonDecodeMetadata,
                                                     MLACommonImpl,
                                                     MLACommonMetadata,
                                                     MLACommonMetadataBuilder)
@@ -38,34 +39,41 @@ class FlashMLABackend(MLACommonBackend):
 @dataclass
-class FlashMLAMetadata(MLACommonMetadata):
-    decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor,
-                                                   torch.Tensor]] = None
-    decode_num_splits: Optional[torch.Tensor] = None
+class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
+    tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor]
+    num_splits: torch.Tensor
+
+
+@dataclass
+class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
+    pass
 class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
     def __init__(self, runner):
-        super().__init__(runner, cls=FlashMLAMetadata)
+        super().__init__(runner)
         self.num_q_heads = self.runner.model_config.get_num_attention_heads(
             self.runner.parallel_config)
-    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
-        m = super().build(num_reqs, num_actual_tokens, max_query_len,
-                          common_prefix_len)
-        if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
-            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
-                get_mla_metadata(
-                    m.seq_lens[:m.num_decode_tokens],
-                    self.num_q_heads,
-                    1,  # MQA for the decode path
-                )
-        return m
+    def _build_decode(self, input_positions: torch.Tensor,
+                      block_table: torch.Tensor,
+                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
+        tile_scheduler_metadata, num_splits = \
+            get_mla_metadata(
+                seq_lens,
+                self.num_q_heads,
+                1,  # MQA for the decode path
+            )
+        return FlashMLADecodeMetadata(
+            input_positions=input_positions,
+            block_table=block_table,
+            seq_lens=seq_lens,
+            tile_scheduler_metadata=tile_scheduler_metadata,
+            num_splits=num_splits,
+        )
 class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
@@ -115,6 +123,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         attn_metadata: FlashMLAMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError("FP8 FlashMLA not yet supported")
@@ -124,14 +134,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
         o, _ = flash_mla_with_kvcache(
             q=q,
             k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
-            block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
-                                                  ...],
-            cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
-                                                 num_decode_tokens],
+            block_table=attn_metadata.decode.block_table,
+            cache_seqlens=attn_metadata.decode.seq_lens,
             head_dim_v=self.kv_lora_rank,
-            tile_scheduler_metadata=attn_metadata.
-            decode_tile_scheduler_metadata,
-            num_splits=attn_metadata.decode_num_splits,
+            tile_scheduler_metadata=attn_metadata.decode.
+            tile_scheduler_metadata,
+            num_splits=attn_metadata.decode.num_splits,
             softmax_scale=self.scale,
             causal=True,
         )

View File

@@ -69,6 +69,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         attn_metadata: MLACommonMetadata,
     ) -> torch.Tensor:
         assert kv_c_and_k_pe_cache.numel() > 0
+        assert attn_metadata.decode is not None
         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError("FP8 Triton MLA not yet supported")
@@ -104,7 +106,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
         # Run MQA
-        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
-                             attn_metadata.block_table, attn_metadata.seq_lens,
-                             attn_logits, num_kv_splits, self.scale, PAGE_SIZE)
+        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
+                             attn_metadata.decode.block_table,
+                             attn_metadata.decode.seq_lens, attn_logits,
+                             num_kv_splits, self.scale, PAGE_SIZE)
         return self._v_up_proj_and_o_proj(o)

View File

@@ -383,8 +383,6 @@ class InputBatch:
             self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
         self.num_tokens[i1], self.num_tokens[i2] =\
             self.num_tokens[i2], self.num_tokens[i1]
-        self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
-            self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
         self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
             self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
         self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
@@ -406,24 +404,47 @@ class InputBatch:
         self.min_p_cpu[i1], self.min_p_cpu[i2] =\
             self.min_p_cpu[i2], self.min_p_cpu[i1]
+        # NOTE: the following is unsafe
+        # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
+        #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
+        # instead, we need to temporarily copy the data for one of the indices
+        # TODO(lucas): optimize this by only copying valid indices
+        tmp = self.token_ids_cpu[i1, ...].copy()
+        self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
+        self.token_ids_cpu[i2, ...] = tmp
         g1 = self.generators.get(i1)
         g2 = self.generators.get(i2)
         if g1 is not None:
             self.generators[i2] = g1
+        else:
+            self.generators.pop(i2, None)
         if g2 is not None:
             self.generators[i1] = g2
+        else:
+            self.generators.pop(i1, None)
         t1 = self.min_tokens.get(i1)
         t2 = self.min_tokens.get(i2)
         if t1 is not None:
             self.min_tokens[i2] = t1
+        else:
+            self.min_tokens.pop(i2, None)
         if t2 is not None:
             self.min_tokens[i1] = t2
+        else:
+            self.min_tokens.pop(i1, None)
         self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
             self.request_lora_mapping[i2], self.request_lora_mapping[i1]
         self.logit_bias[i1], self.logit_bias[i2] =\
             self.logit_bias[i2], self.logit_bias[i1]
+        if self.allowed_token_ids_mask_cpu_tensor is not None:
+            self.allowed_token_ids_mask_cpu_tensor[i1], \
+                self.allowed_token_ids_mask_cpu_tensor[i2] =\
+                self.allowed_token_ids_mask_cpu_tensor[i2], \
+                self.allowed_token_ids_mask_cpu_tensor[i1]
         self.block_table.swap_row(i1, i2)
     def condense(self, empty_req_indices: list[int]) -> None:
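The "unsafe" comment above is the heart of the accuracy fix: a tuple swap of two NumPy row slices binds views, not copies, so the second assignment reads a row that has already been overwritten. Scalar element swaps (like `num_tokens[i1], num_tokens[i2] = ...`) are fine because indexing a single element returns a copy. A small plain-NumPy demonstration, not vLLM code:

```python
# Why the old row swap corrupted token_ids_cpu, and how the copy fixes it.
import numpy as np

a = np.arange(6).reshape(2, 3)       # [[0, 1, 2], [3, 4, 5]]
# Tuple swap of row views: the RHS holds views, so by the time row 1 is
# assigned, row 0 already contains row 1's data.
a[0, :], a[1, :] = a[1, :], a[0, :]
print(a)
# [[3 4 5]
#  [3 4 5]]   <- row 0's original contents are lost

# Safe version, mirroring the fix in swap_states: snapshot one row first.
b = np.arange(6).reshape(2, 3)
tmp = b[0, :].copy()
b[0, :] = b[1, :]
b[1, :] = tmp
print(b)
# [[3 4 5]
#  [0 1 2]]
```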

View File

@@ -456,8 +456,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # Some attention backends (namely MLA) may want to separate requests
         # based on whether the attention computation will be compute-bound or
         # memory-bound. This gives them a hook to do that.
-        self.attn_metadata_builder.reorder_batch(self.input_batch,
-                                                 scheduler_output)
+        modified_batch = self.attn_metadata_builder.reorder_batch(
+            self.input_batch, scheduler_output)
+        if modified_batch:
+            self.input_batch.refresh_sampling_metadata()
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
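With this change, `reorder_batch` reports whether it actually permuted any requests, and the runner rebuilds the sampling metadata only in that case so per-request sampling state stays aligned with the new batch order. A condensed sketch of that contract with hypothetical builder classes and a simplified driver, not the full runner:

```python
# Sketch of the reorder-then-refresh contract between metadata builders and
# the model runner. Names other than reorder_batch/swap_states/
# refresh_sampling_metadata are illustrative.
class NoReorderBuilder:
    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        # e.g. FlashAttention: never reorders, nothing to refresh.
        return False


class DecodeFirstBuilder:
    def reorder_batch(self, input_batch, scheduler_output) -> bool:
        modified = False
        # ... decide which requests to swap, e.g.:
        # input_batch.swap_states(i, j); modified = True
        return modified


def prepare_inputs(builder, input_batch, scheduler_output):
    modified_batch = builder.reorder_batch(input_batch, scheduler_output)
    if modified_batch:
        # Sampling metadata caches per-request state by batch index, so it
        # must be rebuilt whenever requests change position.
        input_batch.refresh_sampling_metadata()
```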