[BugFix] MLA + V1, illegal memory access and accuracy issues (#14253)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
This commit is contained in:
parent
71eaf8969b
commit
f6bb18fd9a
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
@ -9,7 +10,8 @@ import torch
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
|
||||
InputBatch)
|
||||
|
||||
VOCAB_SIZE = 1024
|
||||
NUM_OUTPUT_TOKENS = 20
|
||||
@ -20,6 +22,34 @@ CUDA_DEVICES = [
|
||||
MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
|
||||
def _compare_objs(obj1, obj2):
|
||||
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
|
||||
attr_names = set([
|
||||
a[0] for a in attrs
|
||||
if not (a[0].startswith('__') and a[0].endswith('__'))
|
||||
])
|
||||
for attr_name in attr_names:
|
||||
a = getattr(obj1, attr_name)
|
||||
b = getattr(obj2, attr_name)
|
||||
|
||||
is_same = False
|
||||
if isinstance(a, torch.Tensor):
|
||||
if (a.numel() == 0 or b.numel() == 0):
|
||||
is_same = (a.numel() == 0 and b.numel() == 0)
|
||||
elif torch.allclose(a, b):
|
||||
is_same = True
|
||||
elif isinstance(a, np.ndarray):
|
||||
if np.allclose(a, b):
|
||||
is_same = True
|
||||
elif isinstance(a, (BlockTable, SamplingMetadata)):
|
||||
_compare_objs(a, b)
|
||||
is_same = True # if we make it here must be same
|
||||
elif a == b:
|
||||
is_same = True
|
||||
assert is_same, f"Attribute {attr_name} is different"\
|
||||
f" in {obj1} and {obj2}: {a} != {b}"
|
||||
|
||||
|
||||
def _remove_requests(
|
||||
input_batch: InputBatch, batch_size: int,
|
||||
reqs: list[CachedRequestState]) -> tuple[set[str], list[int]]:
|
||||
@ -254,3 +284,61 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
|
||||
assert torch.allclose(
|
||||
expected_sampling_metadata.allowed_token_ids_mask,
|
||||
sampling_metadata.allowed_token_ids_mask)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("batch_size", [32])
|
||||
@pytest.mark.parametrize("swap_list", [((0, 1), )])
|
||||
def test_swap_states_in_input_batch(device: str, batch_size: int,
|
||||
swap_list: list):
|
||||
"""
|
||||
Tests the logic for managing sampling metadata in the InputBatch.
|
||||
|
||||
This test involves adding a set of requests to the InputBatch,
|
||||
followed by removing a subset of them. Afterward, the batch is compacted,
|
||||
and the `make_sampling_metadata` method is invoked on the batch. The
|
||||
output of `make_sampling_metadata` is then compared against the expected
|
||||
results to ensure correctness.
|
||||
"""
|
||||
input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
max_model_len=1024,
|
||||
max_num_blocks_per_req=10,
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
)
|
||||
ref_input_batch: InputBatch = InputBatch(
|
||||
max_num_reqs=batch_size,
|
||||
max_model_len=1024,
|
||||
max_num_blocks_per_req=10,
|
||||
device=torch.device(device),
|
||||
pin_memory=is_pin_memory_available(),
|
||||
vocab_size=1024,
|
||||
)
|
||||
|
||||
reqs: list[CachedRequestState] = []
|
||||
req_id_reqs = {}
|
||||
req_id_output_token_ids = {}
|
||||
# Add requests
|
||||
for req_index in range(batch_size):
|
||||
req: CachedRequestState = _construct_cached_request_state(req_index)
|
||||
input_batch.add_request(req, req_index)
|
||||
reqs.append(req)
|
||||
req_id_reqs[req.req_id] = req
|
||||
req_id_output_token_ids[req.req_id] = req.output_token_ids
|
||||
|
||||
reordered_reqs = reqs.copy()
|
||||
for swap_pair in swap_list:
|
||||
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
|
||||
reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
|
||||
input_batch.swap_states(swap_pair[0], swap_pair[1])
|
||||
|
||||
for req_index in range(batch_size):
|
||||
req = reordered_reqs[req_index]
|
||||
ref_input_batch.add_request(req, req_index)
|
||||
|
||||
input_batch.refresh_sampling_metadata()
|
||||
ref_input_batch.refresh_sampling_metadata()
|
||||
|
||||
_compare_objs(input_batch, ref_input_batch)
|
||||
|
@ -100,8 +100,8 @@ class FlashAttentionMetadataBuilder:
|
||||
self.runner = runner
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput"):
|
||||
pass
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
return False
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int):
|
||||
|
@ -275,17 +275,47 @@ class MLACommonBackend(AttentionBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonMetadata:
|
||||
class MLACommonPrefillMetadata:
|
||||
""" Prefill Specific Metadata """
|
||||
|
||||
@dataclass
|
||||
class ChunkedContextMetadata:
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling chunked prefill
|
||||
cu_seq_lens: torch.Tensor
|
||||
starts: torch.Tensor
|
||||
seq_tot: list[int]
|
||||
max_seq_lens: list[int]
|
||||
workspace: torch.Tensor
|
||||
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
max_query_len: int
|
||||
chunked_context: Optional[ChunkedContextMetadata] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonDecodeMetadata:
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
|
||||
D = TypeVar("D", bound=MLACommonDecodeMetadata)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MLACommonMetadata(Generic[D]):
|
||||
"""Metadata for MLACommon.
|
||||
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
"""
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
@ -295,30 +325,23 @@ class MLACommonMetadata:
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_prefills: int
|
||||
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For chunked prefill
|
||||
num_decodes: Optional[int] = None
|
||||
num_decode_tokens: Optional[int] = None
|
||||
num_prefills: Optional[int] = None
|
||||
has_context: bool = False
|
||||
context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
|
||||
context_chunk_starts: Optional[torch.Tensor] = None
|
||||
context_chunk_seq_tot: Optional[list[int]] = None
|
||||
context_chunk_max_seq_lens: Optional[list[int]] = None
|
||||
chunked_prefill_workspace: Optional[torch.Tensor] = None
|
||||
decode: Optional[D] = None
|
||||
prefill: Optional[MLACommonPrefillMetadata] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
|
||||
@ -329,10 +352,10 @@ class MLACommonMetadata:
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
|
||||
T = TypeVar("T", bound=MLACommonMetadata)
|
||||
M = TypeVar("M", bound=MLACommonMetadata)
|
||||
|
||||
|
||||
class MLACommonMetadataBuilder(Generic[T]):
|
||||
class MLACommonMetadataBuilder(Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@ -340,8 +363,9 @@ class MLACommonMetadataBuilder(Generic[T]):
|
||||
|
||||
def __init__(self,
|
||||
runner: "GPUModelRunner",
|
||||
cls: Optional[type[T]] = None):
|
||||
self.cls = cls if cls is not None else MLACommonMetadata
|
||||
metadata_cls: Optional[type[M]] = None):
|
||||
self.metadata_cls = metadata_cls \
|
||||
if metadata_cls is not None else MLACommonMetadata
|
||||
self.runner = runner
|
||||
scheduler_config = runner.scheduler_config
|
||||
model_config = runner.model_config
|
||||
@ -375,7 +399,7 @@ class MLACommonMetadataBuilder(Generic[T]):
|
||||
self.page_size = self.runner.block_size
|
||||
|
||||
def reorder_batch(self, input_batch: "InputBatch",
|
||||
scheduler_output: "SchedulerOutput"):
|
||||
scheduler_output: "SchedulerOutput") -> bool:
|
||||
# We now want to reorder the batch so that the "decode" requests are and
|
||||
# the front and the "prefill" requests are at the using the least amount
|
||||
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
|
||||
@ -413,6 +437,7 @@ class MLACommonMetadataBuilder(Generic[T]):
|
||||
num_decodes = len(decodes)
|
||||
num_prefills = len(prefills)
|
||||
first_prefill = 0
|
||||
modified_batch = False
|
||||
|
||||
for i in range(1, min(num_decodes, num_prefills) + 1):
|
||||
# If the decode is at the "back" of the batch, i, we can swap it
|
||||
@ -421,6 +446,7 @@ class MLACommonMetadataBuilder(Generic[T]):
|
||||
input_batch.swap_states(prefills[first_prefill],
|
||||
decodes[num_decodes - i])
|
||||
first_prefill += 1
|
||||
modified_batch = True
|
||||
else:
|
||||
break
|
||||
|
||||
@ -432,10 +458,21 @@ class MLACommonMetadataBuilder(Generic[T]):
|
||||
self._num_decode_tokens = num_decode_tokens
|
||||
self._num_prefill_tokens = num_prefill_tokens
|
||||
|
||||
return modified_batch
|
||||
|
||||
def _build_decode(self, input_positions: torch.Tensor,
|
||||
block_table: torch.Tensor, seq_lens: torch.Tensor):
|
||||
return MLACommonDecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
)
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int) -> T:
|
||||
common_prefix_len: int) -> M:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
|
||||
device = self.runner.device
|
||||
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
|
||||
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
|
||||
device, non_blocking=True)
|
||||
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
|
||||
@ -447,85 +484,103 @@ class MLACommonMetadataBuilder(Generic[T]):
|
||||
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
|
||||
device, non_blocking=True).long()
|
||||
|
||||
context_chunk_cu_seq_lens = None
|
||||
context_chunk_starts = None
|
||||
context_chunk_seq_tot = None
|
||||
context_chunk_max_seq_lens = None
|
||||
prefill_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
reqs_start = self._num_decodes # prefill_start
|
||||
tokens_start = self._num_decode_tokens
|
||||
|
||||
num_computed_tokens_cpu_tensor = \
|
||||
self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
|
||||
context_lens_tensor = \
|
||||
num_computed_tokens_cpu_tensor.to(device, non_blocking=True)
|
||||
context_lens_cpu = self.runner.input_batch.\
|
||||
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
|
||||
context_lens = context_lens_cpu.to(device, non_blocking=True)
|
||||
|
||||
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
||||
and context_lens_tensor[self._num_decodes:].max() > 0:
|
||||
# NOTE: it is recommend you read the `Chunked Prefill` section in
|
||||
# the comment at the top of the file before trying to understand
|
||||
# the following code
|
||||
chunked_context_metadata = None
|
||||
if self.chunked_prefill_enabled and self._num_prefills > 0 \
|
||||
and context_lens.max() > 0:
|
||||
# NOTE: it is recommend you read the `Chunked Prefill` section
|
||||
# in the comment at the top of the file before trying to
|
||||
# understand the following code
|
||||
|
||||
self.has_context = True
|
||||
num_prefills_with_context = (context_lens > 0).sum().item()
|
||||
|
||||
num_prefills_with_context = \
|
||||
(context_lens_tensor[self._num_decodes:] > 0).sum().item()
|
||||
# currently we allocate an equal amount of workspace for each
|
||||
# prefill in the batch, we could probably use a more advanced
|
||||
# algorithm here and allocate more workspace to prefills with
|
||||
# longer context lengths
|
||||
max_context_chunk = \
|
||||
self.chunked_prefill_workspace_size \
|
||||
// num_prefills_with_context
|
||||
|
||||
# currently we allocate an equal amount of workspace for each
|
||||
# prefill in the batch, we could probably use a more advanced
|
||||
# algorithm here and allocate more workspace to prefills with
|
||||
# longer context lengths
|
||||
max_context_chunk = \
|
||||
self.chunked_prefill_workspace_size // num_prefills_with_context
|
||||
# align max_context_chunk to page_size by rounding down,
|
||||
# currently the `gather_cache` kernel cannot handle
|
||||
# `context_chunk_starts` that are not aligned to page_size
|
||||
max_context_chunk = round_down(max_context_chunk,
|
||||
self.page_size)
|
||||
|
||||
# align max_context_chunk to page_size by rounding down,
|
||||
# currently the `gather_cache` kernel cannot handle
|
||||
# `context_chunk_starts` that are not aligned to page_size
|
||||
max_context_chunk = round_down(max_context_chunk, self.page_size)
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
|
||||
assert max_context_chunk > 0
|
||||
num_chunks = cdiv(context_lens.max(), max_context_chunk)
|
||||
|
||||
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
||||
# `num_prefills_with_context = 4`, create a tensor that looks like
|
||||
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
||||
context_chunk_starts = \
|
||||
torch.arange(num_chunks, device=device, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, self._num_prefills) \
|
||||
* max_context_chunk
|
||||
chunk_ends = torch.min(context_lens_tensor[self._num_decodes:] \
|
||||
.unsqueeze(0), context_chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
|
||||
_context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
|
||||
torch.int32)
|
||||
zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \
|
||||
.unsqueeze(-1)
|
||||
context_chunk_cu_seq_lens = \
|
||||
torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
|
||||
context_chunk_max_seq_lens = \
|
||||
chunk_seq_lens.max(dim=1).values.tolist()
|
||||
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
|
||||
assert max(context_chunk_seq_tot) <= \
|
||||
self.chunked_prefill_workspace_size
|
||||
# if `max_context_chunk = 256`, `num_chunks = 3`, and
|
||||
# `num_prefills_with_context = 4`, create a tensor that looks
|
||||
# like
|
||||
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
|
||||
chunk_starts = \
|
||||
torch.arange(num_chunks, device=device, dtype=torch.int32) \
|
||||
.unsqueeze(1).expand(-1, self._num_prefills) \
|
||||
* max_context_chunk
|
||||
chunk_ends = torch.min(context_lens.unsqueeze(0),
|
||||
chunk_starts + max_context_chunk)
|
||||
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
|
||||
_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
|
||||
torch.int32)
|
||||
zero = torch.zeros(num_chunks,
|
||||
dtype=torch.int32,
|
||||
device=device).unsqueeze(-1)
|
||||
|
||||
return self.cls(
|
||||
input_positions=input_positions,
|
||||
chunked_context_metadata = \
|
||||
MLACommonPrefillMetadata.ChunkedContextMetadata(
|
||||
cu_seq_lens=torch.cat(
|
||||
[zero, _chunk_cu_seq_lens], dim=1),
|
||||
starts=chunk_starts,
|
||||
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
|
||||
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
|
||||
workspace=self.chunked_prefill_workspace,
|
||||
)
|
||||
|
||||
assert max(chunked_context_metadata.max_seq_lens) <= \
|
||||
self.chunked_prefill_workspace_size
|
||||
|
||||
prefill_metadata = MLACommonPrefillMetadata(
|
||||
input_positions=input_positions[tokens_start:],
|
||||
block_table=block_table[reqs_start:, ...],
|
||||
query_start_loc=query_start_loc[reqs_start:] -
|
||||
query_start_loc[reqs_start],
|
||||
max_query_len=seq_lens[reqs_start:].max().item(),
|
||||
chunked_context=chunked_context_metadata,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
if self._num_decodes > 0:
|
||||
decode_metadata = self._build_decode(
|
||||
input_positions=input_positions[:self._num_decode_tokens],
|
||||
block_table=block_table[:self._num_decodes, ...],
|
||||
seq_lens=seq_lens[:self._num_decodes],
|
||||
)
|
||||
|
||||
return self.metadata_cls(
|
||||
num_actual_tokens=num_actual_tokens,
|
||||
max_query_len=max_query_len,
|
||||
query_start_loc=query_start_loc,
|
||||
max_seq_len=max_seq_len,
|
||||
seq_lens=seq_lens,
|
||||
block_table=block_table,
|
||||
slot_mapping=slot_mapping,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
# MLACommonMetadata Chunk prefill specific
|
||||
num_decodes=self._num_decodes,
|
||||
num_decode_tokens=self._num_decode_tokens,
|
||||
num_prefills=self._num_prefills,
|
||||
context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
|
||||
context_chunk_starts=context_chunk_starts,
|
||||
context_chunk_seq_tot=context_chunk_seq_tot,
|
||||
context_chunk_max_seq_lens=context_chunk_max_seq_lens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
"""
|
||||
NOTE: Please read the comment at the top of the file before trying to
|
||||
understand this class
|
||||
@ -798,28 +853,24 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
):
|
||||
assert attn_metadata.num_prefills is not None
|
||||
assert attn_metadata.context_chunk_seq_tot is not None
|
||||
assert attn_metadata.context_chunk_cu_seq_lens is not None
|
||||
assert attn_metadata.context_chunk_starts is not None
|
||||
assert attn_metadata.context_chunk_max_seq_lens is not None
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_metadata = attn_metadata.prefill
|
||||
assert prefill_metadata.chunked_context is not None
|
||||
|
||||
output = None
|
||||
iters = len(attn_metadata.context_chunk_seq_tot)
|
||||
|
||||
assert attn_metadata.chunked_prefill_workspace is not None
|
||||
workspace = attn_metadata.chunked_prefill_workspace
|
||||
iters = len(prefill_metadata.chunked_context.seq_tot)
|
||||
workspace = prefill_metadata.chunked_context.workspace
|
||||
|
||||
for i in range(iters):
|
||||
toks = attn_metadata.context_chunk_seq_tot[i]
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
|
||||
ops.gather_cache(
|
||||
src_cache=kv_c_and_k_pe_cache,
|
||||
dst=workspace,
|
||||
block_table=attn_metadata.block_table,
|
||||
cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i],
|
||||
block_table=prefill_metadata.block_table,
|
||||
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
|
||||
batch_size=attn_metadata.num_prefills,
|
||||
seq_starts=attn_metadata.context_chunk_starts[i],
|
||||
seq_starts=prefill_metadata.chunked_context.starts[i],
|
||||
)
|
||||
|
||||
kv_c_normed = workspace[:toks]\
|
||||
@ -845,10 +896,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_padded,
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i],
|
||||
max_seqlen_q=attn_metadata.max_query_len,
|
||||
max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i],
|
||||
cu_seqlens_q=prefill_metadata.query_start_loc,
|
||||
cu_seqlens_k=prefill_metadata.chunked_context.cu_seq_lens[i],
|
||||
max_seqlen_q=prefill_metadata.max_query_len,
|
||||
max_seqlen_k=prefill_metadata.chunked_context.max_seq_lens[i],
|
||||
softmax_scale=self.scale,
|
||||
causal=False, # Context is unmasked
|
||||
return_softmax_lse=True,
|
||||
@ -881,7 +932,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
has_context = attn_metadata.has_context
|
||||
assert attn_metadata.prefill is not None
|
||||
|
||||
has_context = attn_metadata.prefill.chunked_context is not None
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
@ -898,10 +951,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_padded,
|
||||
cu_seqlens_q=attn_metadata.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.max_query_len,
|
||||
max_seqlen_k=attn_metadata.max_seq_len,
|
||||
cu_seqlens_q=attn_metadata.prefill.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.prefill.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.prefill.max_query_len,
|
||||
max_seqlen_k=attn_metadata.prefill.max_query_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
return_softmax_lse=has_context,
|
||||
@ -934,7 +987,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
attn_metadata: M,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -945,7 +998,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
attn_metadata: M,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@ -966,7 +1019,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
assert hasattr(attn_metadata, "input_positions")
|
||||
|
||||
assert attn_metadata.num_decodes is not None and \
|
||||
attn_metadata.num_prefills is not None and \
|
||||
@ -978,28 +1030,27 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
|
||||
|
||||
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
|
||||
decode_k_pe = k_pe[:num_decode_tokens]
|
||||
decode_input_positions = \
|
||||
attn_metadata.input_positions[:num_decode_tokens]
|
||||
|
||||
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
|
||||
prefill_k_pe = k_pe[num_decode_tokens:]
|
||||
prefill_input_positions = \
|
||||
attn_metadata.input_positions[num_decode_tokens:]
|
||||
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
|
||||
|
||||
if has_decode:
|
||||
assert attn_metadata.decode is not None
|
||||
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
|
||||
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
|
||||
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
||||
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
|
||||
decode_input_positions, decode_q_pe, decode_k_pe)
|
||||
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
|
||||
|
||||
if has_prefill:
|
||||
assert attn_metadata.prefill is not None
|
||||
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
|
||||
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
|
||||
prefill_input_positions, prefill_q_pe, prefill_k_pe)
|
||||
attn_metadata.prefill.input_positions, prefill_q_pe,
|
||||
prefill_k_pe)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
|
@ -11,6 +11,7 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
|
||||
is_flashmla_supported)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder)
|
||||
@ -38,34 +39,41 @@ class FlashMLABackend(MLACommonBackend):
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata):
|
||||
decode_tile_scheduler_metadata: Optional[tuple[torch.Tensor,
|
||||
torch.Tensor]] = None
|
||||
decode_num_splits: Optional[torch.Tensor] = None
|
||||
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
tile_scheduler_metadata: tuple[torch.Tensor, torch.Tensor]
|
||||
num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
|
||||
pass
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
|
||||
|
||||
def __init__(self, runner):
|
||||
super().__init__(runner, cls=FlashMLAMetadata)
|
||||
super().__init__(runner)
|
||||
|
||||
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
|
||||
self.runner.parallel_config)
|
||||
|
||||
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
|
||||
common_prefix_len: int):
|
||||
m = super().build(num_reqs, num_actual_tokens, max_query_len,
|
||||
common_prefix_len)
|
||||
def _build_decode(self, input_positions: torch.Tensor,
|
||||
block_table: torch.Tensor,
|
||||
seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
|
||||
tile_scheduler_metadata, num_splits = \
|
||||
get_mla_metadata(
|
||||
seq_lens,
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
|
||||
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
|
||||
get_mla_metadata(
|
||||
m.seq_lens[:m.num_decode_tokens],
|
||||
self.num_q_heads,
|
||||
1, # MQA for the decode path
|
||||
)
|
||||
|
||||
return m
|
||||
return FlashMLADecodeMetadata(
|
||||
input_positions=input_positions,
|
||||
block_table=block_table,
|
||||
seq_lens=seq_lens,
|
||||
tile_scheduler_metadata=tile_scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
)
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
@ -115,6 +123,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
attn_metadata: FlashMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 FlashMLA not yet supported")
|
||||
|
||||
@ -124,14 +134,12 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
|
||||
o, _ = flash_mla_with_kvcache(
|
||||
q=q,
|
||||
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
|
||||
block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
|
||||
...],
|
||||
cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
|
||||
num_decode_tokens],
|
||||
block_table=attn_metadata.decode.block_table,
|
||||
cache_seqlens=attn_metadata.decode.seq_lens,
|
||||
head_dim_v=self.kv_lora_rank,
|
||||
tile_scheduler_metadata=attn_metadata.
|
||||
decode_tile_scheduler_metadata,
|
||||
num_splits=attn_metadata.decode_num_splits,
|
||||
tile_scheduler_metadata=attn_metadata.decode.
|
||||
tile_scheduler_metadata,
|
||||
num_splits=attn_metadata.decode.num_splits,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
|
@ -69,6 +69,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
attn_metadata: MLACommonMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
assert attn_metadata.decode is not None
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
@ -104,7 +106,8 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
# Run MQA
|
||||
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
||||
attn_metadata.block_table, attn_metadata.seq_lens,
|
||||
attn_logits, num_kv_splits, self.scale, PAGE_SIZE)
|
||||
attn_metadata.decode.block_table,
|
||||
attn_metadata.decode.seq_lens, attn_logits,
|
||||
num_kv_splits, self.scale, PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
||||
|
@ -383,8 +383,6 @@ class InputBatch:
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
self.num_tokens[i1], self.num_tokens[i2] =\
|
||||
self.num_tokens[i2], self.num_tokens[i1]
|
||||
self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
||||
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
||||
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
||||
@ -406,24 +404,47 @@ class InputBatch:
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
||||
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporiarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
g1 = self.generators.get(i1)
|
||||
g2 = self.generators.get(i2)
|
||||
if g1 is not None:
|
||||
self.generators[i2] = g1
|
||||
else:
|
||||
self.generators.pop(i2, None)
|
||||
if g2 is not None:
|
||||
self.generators[i1] = g2
|
||||
else:
|
||||
self.generators.pop(i1, None)
|
||||
|
||||
t1 = self.min_tokens.get(i1)
|
||||
t2 = self.min_tokens.get(i2)
|
||||
if t1 is not None:
|
||||
self.min_tokens[i2] = t1
|
||||
else:
|
||||
self.min_tokens.pop(i2, None)
|
||||
if t2 is not None:
|
||||
self.min_tokens[i1] = t2
|
||||
else:
|
||||
self.min_tokens.pop(i1, None)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
self.logit_bias[i1], self.logit_bias[i2] =\
|
||||
self.logit_bias[i2], self.logit_bias[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self, empty_req_indices: list[int]) -> None:
|
||||
|
@ -456,8 +456,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Some attention backends (namely MLA) may want to separate requests
|
||||
# based on if the attention computation will be compute-bound or
|
||||
# memory-bound. This gives them a hook to do that.
|
||||
self.attn_metadata_builder.reorder_batch(self.input_batch,
|
||||
scheduler_output)
|
||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
if modified_batch:
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
# OPTIMIZATION: Start copying the block table first.
|
||||
# This way, we can overlap the copy with the following CPU operations.
|
||||
|
Loading…
x
Reference in New Issue
Block a user