# SPDX-License-Identifier: Apache-2.0
"""Kernel test utils"""

import itertools
import random
import unittest
from numbers import Number
from typing import (Any, Dict, List, NamedTuple, Optional, Sequence, Tuple,
                    Type, Union)

import pytest
import torch
from torch._prims_common import TensorLikeType

from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.platforms.interface import _Backend
from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL,
                        STR_XFORMERS_ATTN_VAL, make_tensor_with_pad)

# For now, disable "test_aot_dispatch_dynamic" since there are some
# bugs related to this test in PyTorch 2.4.
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


class QKVInputs(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    query/key/values and their sequence lengths.

    Attributes:

    * {query,key,value}: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) attention inputs
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_seq_lens: List[int]
    kv_seq_lens: List[int]


class QKVO(NamedTuple):
    '''
    Data structure for representing unpacked attention inputs,
    alongside unpacked known-correct attention output

    Attributes:

    * qkv: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) attention inputs
    * ideal_output: unpacked (batch_size x padded_seq_len x
      num_heads x head_size) known-correct attention output
    '''

    qkv: QKVInputs
    ideal_output: torch.Tensor


class PackedQKVInputs(NamedTuple):
    '''
    Data structure for representing packed attention inputs

    Attributes:

    * {query,key,value}: packed (number_of_tokens x num_heads
      x head_size) attention inputs
    * q_start_loc_list: list of query start locations within packed tensor
    * kv_start_loc_list: shared list of key/value start locations within
      packed tensor
    * q_seq_lens: query sequence lengths list
    * kv_seq_lens: shared key/value sequence lengths list
    '''

    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    q_start_loc_list: Optional[List[int]]
    kv_start_loc_list: Optional[List[int]]
    q_seq_lens: Optional[List[int]]
    kv_seq_lens: Optional[List[int]]


class PackedQKVO(NamedTuple):
    '''
    Data structure for representing packed attention inputs,
    alongside packed known-correct attention output

    Attributes:

    * packed_qkv: packed (number_of_tokens x num_heads
      x head_size) attention inputs
    * ideal_output: packed (number_of_tokens x num_heads
      x head_size) known-correct attention output
    '''

    packed_qkv: Optional[PackedQKVInputs]
    ideal_output: torch.Tensor


class KVMemoryMap(NamedTuple):
    '''
    Data structure for encapsulating KV cache memory mapping.

    Attributes:

    * block_tables: KV cache block tables
    * slot_mapping: mapping of sequence offset to physical address
    '''

    block_tables: torch.Tensor
    slot_mapping: torch.Tensor


class PhaseTestParameters(NamedTuple):
    '''
    Data structure for encapsulating the test parameters
    for a given test "phase" (prefill or decode phase) and attention
    scenario (encoder, decoder-self, encoder/decoder-cross)

    Attributes:

    * packed_qkvo: packed (number_of_tokens x num_heads
      x head_size) attention inputs & known-correct
      output
    * kv_mmap: KV cache memory mapping, specific to this test phase &
      attention scenario
    '''

    packed_qkvo: PackedQKVO
    kv_mmap: Optional[KVMemoryMap]


def maybe_make_int_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D int torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D int torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.int, device=device)


def maybe_make_long_tensor(
    _list: Optional[List[int]],
    device: Union[torch.device, str],
) -> Optional[torch.Tensor]:
    '''
    Convert Python int list to a 1D long torch.Tensor on `device`

    Returns:

    * If _list is not None: 1D long torch.Tensor on `device`
    * None otherwise
    '''
    return None if _list is None else torch.tensor(
        _list, dtype=torch.long, device=device)


def maybe_max(_list: Optional[List]) -> Optional[Number]:
    '''
    Returns:

    * If _list is not None: max(_list)
    * None otherwise
    '''
    return None if _list is None else max(_list)


def make_causal_mask(
    q_max_seq_len: int,
    kv_max_seq_len: int,
) -> torch.Tensor:
    '''
    Create a q_max_seq_len x kv_max_seq_len causal mask

    Arguments:

    * q_max_seq_len: query max seq len
    * kv_max_seq_len: key/value max seq len

    Returns:

    * 2D tensor, q_max_seq_len x kv_max_seq_len
    '''

    # Create a matrix where entry (i, j) is 1 if j > i (the strictly upper
    # triangle, i.e. future key positions)
    mask = torch.triu(torch.ones(q_max_seq_len, kv_max_seq_len), diagonal=1)
    # Replace 1 with float('-inf') and 0 with 0.0
    mask = mask.masked_fill(mask == 1,
                            float('-inf')).masked_fill(mask == 0, 0.0)
    return mask
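# Illustrative sketch (not executed; the size 3 is an arbitrary choice):
# make_causal_mask(3, 3) yields
#
#   [[0., -inf, -inf],
#    [0.,   0., -inf],
#    [0.,   0.,   0.]]
#
# so adding it to attention weights masks out future key positions.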


def override_backend_env_variable(mpatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    '''
    Override the environment variable indicating the vLLM backend temporarily,
    using pytest monkeypatch to ensure that the env var gets
    reset once the test context exits.

    Arguments:

    * mpatch: pytest monkeypatch instance
    * backend_name: attention backend name to force
    '''
    mpatch.setenv(STR_BACKEND_ENV_VAR, backend_name)


def ref_masked_attention(query: torch.Tensor,
                         key: torch.Tensor,
                         value: torch.Tensor,
                         scale: float,
                         custom_mask: Optional[torch.Tensor] = None,
                         q_seq_lens: Optional[List] = None,
                         kv_seq_lens: Optional[List] = None) -> torch.Tensor:
    '''
    "Golden" masked attention reference. Supports two types of masking:

    * Basic attention mask, utilizing {q,kv}_seq_lens args to mask out
      padding elements
    * Custom attention mask, which can force an arbitrary mask tensor, e.g.
      a causal mask

    Arguments:

    * query: batch_size x q_padded_seq_len x num_heads x head_size
    * key: batch_size x kv_padded_seq_len x num_heads x head_size
    * value: batch_size x kv_padded_seq_len x num_heads x head_size
    * scale: Attention scale factor
    * custom_mask: custom attention mask; good place to inject a causal
      attention mask
    * q_seq_lens: list of unpadded query seq_lens for each batch index
    * kv_seq_lens: list of unpadded key/value seq_lens for each batch index

    Returns:

    * Attention result, batch_size x q_padded_seq_len x num_heads x head_size
    '''

    assert q_seq_lens is not None
    assert kv_seq_lens is not None

    batch_size = query.shape[0]
    assert (len(q_seq_lens) == batch_size)
    assert (len(kv_seq_lens) == batch_size)

    attn_weights = scale * torch.einsum("bqhd,bkhd->bhqk", query, key).float()

    # Basic attention mask, derived from seq lens
    if (q_seq_lens is not None) or (kv_seq_lens is not None):
        attn_mask = torch.zeros_like(attn_weights)
        if q_seq_lens is not None:
            for bdx, plen in enumerate(q_seq_lens):
                attn_mask[bdx, :, plen:, :] = -torch.inf
        if kv_seq_lens is not None:
            for bdx, plen in enumerate(kv_seq_lens):
                attn_mask[bdx, :, :, plen:] = -torch.inf

        attn_weights = attn_weights + attn_mask.float()

    # Custom attention mask
    if custom_mask is not None:
        attn_weights = attn_weights + custom_mask.float()

    attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
    out = torch.einsum("bhqk,bkhd->bqhd", attn_weights, value)
    return out
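# Illustrative usage sketch (not executed; the tensor sizes below are
# arbitrary assumptions). A causal mask from make_causal_mask() can be
# supplied via custom_mask on top of the padding masks derived from the
# seq-len lists:
#
#   batch_size, seq_len, num_heads, head_size = 2, 8, 4, 16
#   q = torch.rand(batch_size, seq_len, num_heads, head_size)
#   k = torch.rand(batch_size, seq_len, num_heads, head_size)
#   v = torch.rand(batch_size, seq_len, num_heads, head_size)
#   out = ref_masked_attention(q, k, v,
#                              scale=head_size**-0.5,
#                              custom_mask=make_causal_mask(seq_len, seq_len),
#                              q_seq_lens=[8, 5],
#                              kv_seq_lens=[8, 5])
#   # out.shape == (batch_size, seq_len, num_heads, head_size)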


def make_qkv(
    batch_size: int,
    max_q_seq_len: int,
    max_kv_seq_len: Optional[int],
    num_heads: int,
    head_size: int,
    device: Union[torch.device, str],
    force_kv_seq_lens: Optional[List[int]] = None,
    attn_type: AttentionType = AttentionType.ENCODER_DECODER,
    force_max_len: bool = False,
) -> Tuple[QKVInputs, QKVInputs, QKVInputs]:
    '''
    Construct QKV test tensors for self- and cross-attention.

    Generates three query/key/value triplets:

    * "Baseline" query/key/value (for input to reference attention function)
    * "Prefill" query/key/value (last sequence offset zero'd out, for use as
      input to prefill kernel)
    * "Decode" query/key/value (only the last sequence offset from baseline,
      for use as input to decode kernel)

    Each Q/K/V triplet is associated with a list of q seqlens and a list of k/v
    seqlens

    Arguments:

    * batch_size
    * max_q_seq_len: max query seq len
    * max_kv_seq_len: max key/value seq len
    * num_heads
    * head_size
    * force_kv_seq_lens: if not None, overrides kv sequence lengths
    * attn_type: encoder, decoder self, or enc/dec cross attention. For
      ENCODER_DECODER (cross-attention), key/value seqlens may differ from
      query seqlens; otherwise query/key/value seqlens match at each batch
      index (and max_kv_seq_len is unused)
    * force_max_len: if True, all query seqlens are max_q_seq_len; o/w query
      seqlens are random in [2, max_q_seq_len]. Same for key/value seqlens
      and max_kv_seq_len, unless the key/value seqlens are tied to the query
      seqlens by the attention type
    * device: CPU or CUDA device

    Returns:

    * Overall QKVInputs structure (containing full unpacked Q/K/V tensors)
    * Prefill QKVInputs structure (containing all but the last sequence offset)
    * Decode QKVInputs structure (containing only the last sequence offset)
    '''

    if force_max_len:
        q_seq_lens = [max_q_seq_len for _ in range(batch_size)]
    else:
        q_seq_lens = [
            random.randint(2, max_q_seq_len) for _ in range(batch_size)
        ]
    kv_seq_lens = None
    if force_kv_seq_lens is not None:
        kv_seq_lens = force_kv_seq_lens
    elif attn_type != AttentionType.ENCODER_DECODER:
        # K,V seq lens match Q for self-attention
        kv_seq_lens = q_seq_lens
    else:
        # K,V seq lens are distinct from Q seq lens & random
        assert max_kv_seq_len is not None
        if force_max_len:
            kv_seq_lens = [max_kv_seq_len] * batch_size
        else:
            kv_seq_lens = [
                random.randint(2, max_kv_seq_len) for _ in range(batch_size)
            ]

    query = torch.rand(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    key = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    value = torch.rand(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    prefill_query = torch.zeros(
        (batch_size, max_q_seq_len, num_heads, head_size)).to(device)
    prefill_key = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)
    prefill_value = torch.zeros(
        (batch_size, max_kv_seq_len, num_heads, head_size)).to(device)

    decode_query = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)
    decode_key = torch.zeros((batch_size, 1, num_heads, head_size)).to(device)
    decode_value = torch.zeros(
        (batch_size, 1, num_heads, head_size)).to(device)

    for bdx, (q_seq_len, kv_seq_len) in enumerate(zip(q_seq_lens,
                                                      kv_seq_lens)):
        query[bdx, q_seq_len:, :, :] = 0
        key[bdx, kv_seq_len:, :, :] = 0
        value[bdx, kv_seq_len:, :, :] = 0

        prefill_query[bdx,
                      0:(q_seq_len - 1), :, :] = query[bdx,
                                                       0:(q_seq_len - 1), :, :]
        prefill_key[bdx,
                    0:(kv_seq_len - 1), :, :] = key[bdx,
                                                    0:(kv_seq_len - 1), :, :]
        prefill_value[bdx, 0:(kv_seq_len -
                              1), :, :] = value[bdx, 0:(kv_seq_len - 1), :, :]

        decode_query[bdx, :, :, :] = query[bdx,
                                           (q_seq_len - 1):q_seq_len, :, :]
        decode_key[bdx, :, :, :] = key[bdx, (kv_seq_len - 1):kv_seq_len, :, :]
        decode_value[bdx, :, :, :] = value[bdx,
                                           (kv_seq_len - 1):kv_seq_len, :, :]

    prefill_q_seq_lens = [plen - 1 for plen in q_seq_lens]
    prefill_kv_seq_lens = [plen - 1 for plen in kv_seq_lens]

    decode_q_seq_lens = [1 for _ in q_seq_lens]
    decode_kv_seq_lens = [1 for _ in kv_seq_lens]

    return (
        QKVInputs(
            query,  # Overall QKV inputs
            key,
            value,
            q_seq_lens,
            kv_seq_lens),
        QKVInputs(
            prefill_query,  # Prefill subset of QKV sequences
            prefill_key,
            prefill_value,
            prefill_q_seq_lens,
            prefill_kv_seq_lens),
        QKVInputs(
            decode_query,  # Decode subset of QKV sequences
            decode_key,
            decode_value,
            decode_q_seq_lens,
            decode_kv_seq_lens))
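# Illustrative usage sketch (not executed; sizes and device are arbitrary
# assumptions):
#
#   qkv, prefill_qkv, decode_qkv = make_qkv(batch_size=4,
#                                           max_q_seq_len=64,
#                                           max_kv_seq_len=32,
#                                           num_heads=8,
#                                           head_size=64,
#                                           device='cuda:0')
#
# With the default attn_type (ENCODER_DECODER), q_seq_lens and kv_seq_lens are
# drawn independently; prefill_qkv drops each sequence's final token and
# decode_qkv retains only that final token.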


def pack_tensor(
        unpacked_tensor: torch.Tensor, seq_lens: List[int],
        device: Union[torch.device, str]) -> Tuple[torch.Tensor, List[int]]:
    '''
    Pack a batch_size x padded_seq_len x num_heads x head_size tensor into an
    unpadded number_of_tokens x num_heads x head_size tensor, where
    number_of_tokens = sum(seq_lens)

    Arguments:

    * unpacked_tensor: batch_size x padded_seq_len x num_heads x head_size
    * seq_lens: list of token counts for each seq
    * device: CPU or CUDA device

    Returns:

    * packed_tensor: number_of_tokens x num_heads x head_size
    * start_loc_list: start idx of each batch elt in packed_tensor; [0] +
      list(itertools.accumulate(seq_lens))
    '''

    num_tok = sum(seq_lens)
    num_heads = unpacked_tensor.shape[-2]
    head_size = unpacked_tensor.shape[-1]
    start_loc_list = [0] + list(itertools.accumulate(seq_lens))
    packed_tensor = torch.zeros((num_tok, num_heads, head_size), device=device)

    for bdx, (seq_len, start_loc) in enumerate(zip(seq_lens, start_loc_list)):

        packed_tensor[start_loc:(
            start_loc + seq_len), :, :] = unpacked_tensor[bdx, :seq_len, :, :]

    return packed_tensor, start_loc_list
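# Illustrative sketch (not executed; sizes are arbitrary assumptions): for a
# (2 x 4 x num_heads x head_size) padded tensor and seq_lens=[3, 2],
# pack_tensor() returns a (5 x num_heads x head_size) tensor together with
# start_loc_list=[0, 3, 5]; rows 0-2 come from batch element 0 and rows 3-4
# from batch element 1.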


def pack_qkv(qkv: QKVInputs, device: Union[torch.device,
                                           str]) -> PackedQKVInputs:
    '''
    Individually pack each of Q, K and V, each with dimensions batch_size x
    padded_seq_len x num_heads x head_size, into respective number_of_tokens x
    num_heads x head_size tensors.

    For Q, number_of_tokens = sum(q_seq_lens).

    For K and V, number_of_tokens = sum(kv_seq_lens)

    Arguments:

    * qkv: Unpacked (batch_size x padded_seq_len x num_heads x head_size)
      attention inputs
    * device: CPU or CUDA device

    Returns:

    * Packed (number_of_tokens x num_heads x head_size) QKV inputs
      derived from unpacked inputs
    '''

    if qkv.query is None:
        packed_query = None
        q_start_loc_list = None
    else:
        packed_query, q_start_loc_list = pack_tensor(qkv.query,
                                                     qkv.q_seq_lens,
                                                     device=device)
    packed_key, kv_start_loc_list = pack_tensor(qkv.key,
                                                qkv.kv_seq_lens,
                                                device=device)
    packed_value, _ = pack_tensor(qkv.value, qkv.kv_seq_lens, device=device)
    return PackedQKVInputs(
        packed_query, packed_key, packed_value, q_start_loc_list,
        kv_start_loc_list,
        (None if q_start_loc_list is None else qkv.q_seq_lens),
        qkv.kv_seq_lens)


def make_backend(backend_name: str) -> AttentionBackend:
    '''
    Construct the backend instance determined by the backend_name string
    argument.

    "XFORMERS" -> construct xformers backend

    TODO: other backends

    Note: at time of writing the Attention wrapper automatically selects
    its own backend for Attention.forward(); so the backend instance which
    you generate with this function is not meant to be used for *running*
    inference, but rather for generating compatible metadata structures
    using backend.make_metadata()


    Returns:

    * Backend instance
    '''
    if backend_name == STR_XFORMERS_ATTN_VAL:
        # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs.
        from vllm.attention.backends.xformers import XFormersBackend
        return XFormersBackend()
    elif backend_name == STR_FLASH_ATTN_VAL:
        from vllm.attention.backends.flash_attn import FlashAttentionBackend
        return FlashAttentionBackend()

    raise AssertionError(
        f"Unrecognized backend_name {backend_name} for unit test")


def _make_metadata_tensors(
    seq_lens: Optional[List[int]],
    context_lens: Optional[List[int]],
    encoder_seq_lens: Optional[List[int]],
    device: Union[torch.device, str],
) -> Tuple[torch.Tensor, torch.Tensor, Any, Any, Optional[torch.Tensor],
           torch.Tensor, torch.Tensor, Optional[int]]:
    '''
    Build the scalar & tensor values required to build an attention metadata
    structure.

    Arguments:

    * seq_lens: list of token-counts for each decoder input seq
    * context_lens: list of context length values for each seq
    * encoder_seq_lens: list of token-counts for each encoder input seq
    * device: CPU or CUDA device

    Returns:

    * seq_lens_tensor: decoder seq_lens list, as tensor
    * context_lens_tensor: context_lens list, as tensor
    * max_context_len: max(context_lens)
    * max_seq_len: max(seq_lens)
    * seq_start_loc: start idx of each sequence
    * encoder_seq_lens_tensor: encoder seq_lens list, as tensor
    * encoder_seq_start_loc: start idx of each encoder sequence
    * max_encoder_seq_len: max(encoder_seq_lens)
    '''
    seq_lens_tensor = maybe_make_int_tensor(seq_lens, device)
    context_lens_tensor = maybe_make_int_tensor(context_lens, device)
    max_context_len = maybe_max(context_lens)
    max_seq_len = maybe_max(seq_lens)

    encoder_seq_lens_tensor = maybe_make_int_tensor(encoder_seq_lens, device)
    max_encoder_seq_len = (None if encoder_seq_lens is None else
                           max(encoder_seq_lens))

    seq_start_loc = None

    if seq_lens_tensor is not None:
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=seq_lens_tensor.device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])

    encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] + 1,
                                        dtype=torch.int32,
                                        device=encoder_seq_lens_tensor.device)
    torch.cumsum(encoder_seq_lens_tensor,
                 dim=0,
                 dtype=encoder_seq_start_loc.dtype,
                 out=encoder_seq_start_loc[1:])

    return (seq_lens_tensor, context_lens_tensor, max_context_len, max_seq_len,
            seq_start_loc, encoder_seq_lens_tensor, encoder_seq_start_loc,
            max_encoder_seq_len)


def make_kv_cache(num_blocks: int,
                  num_heads: int,
                  head_size: int,
                  block_size: int,
                  device: Union[torch.device, str],
                  backend: str,
                  default_val: float = 0.0) -> torch.Tensor:
    '''
    Create a fake KV cache.

    Arguments:

    * num_blocks: number of blocks in the KV cache
    * num_heads: number of attention heads
    * head_size: head dimension
    * block_size: number of offsets within a block
    * device: CPU or CUDA device
    * backend: 'XFORMERS' or 'FLASH_ATTN'; selects the KV cache layout
    * default_val: initialization value for KV cache elements

    Returns:

    * kv_cache: 2 x num_blocks x (block_size * num_heads * head_size)
      * for backend 'XFORMERS'
    * kv_cache: 2 x num_blocks x block_size x num_heads x head_size
      * for backend 'FLASH_ATTN'
    '''
    if backend == 'XFORMERS':
        kv_cache = torch.rand(
            (2, num_blocks, block_size * num_heads * head_size)).to(device)
    elif backend == 'FLASH_ATTN':
        kv_cache = torch.rand(
            (2, num_blocks, block_size, num_heads, head_size)).to(device)
    else:
        raise ValueError(
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
            f"'FLASH_ATTN'.")
    if default_val is not None:
        kv_cache[:, :, :] = default_val
    return kv_cache
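# Illustrative usage sketch (not executed; sizes and device are arbitrary
# assumptions):
#
#   kv_cache = make_kv_cache(num_blocks=128,
#                            num_heads=8,
#                            head_size=64,
#                            block_size=16,
#                            device='cuda:0',
#                            backend='XFORMERS')
#   # kv_cache.shape == (2, 128, 16 * 8 * 64); every element is 0.0 because
#   # default_val defaults to 0.0 and overwrites the initial random fill.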


def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
    '''
    Compute the minimum number of blocks required to hold num_tokens tokens,
    given block_size. Note that this rounds up and, when num_tokens is an
    exact multiple of block_size, provisions one extra spare block.
    '''
    return (num_tokens + block_size) // block_size
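# Illustrative sketch (not executed): with block_size=16,
# _num_tokens_to_min_blocks(17, 16) == 2 and
# _num_tokens_to_min_blocks(16, 16) == 2 as well, i.e. an exact multiple of
# block_size gets one spare block.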


def make_empty_slot_mapping_tensor(device: Union[torch.device, str]):
    return maybe_make_long_tensor([], device)


def make_empty_block_tables_tensor(device: Union[torch.device, str]):
    return torch.tensor([], device=device)


def split_slot_mapping(slot_mapping_list: torch.Tensor, seq_lens: List[int],
                       device: Union[torch.device, str]):
    '''
    Split a slot mapping into valid prefill- and decode-phase slot mappings.

    Context:
    * Your goal is to test (1) prefill of N prompts, with prompt-lengths
      {K_i \\forall i \\in [0,N)}, followed by (2) decoding of a single token
      for all N prompts (N tokens total); the resultant sequence lengths
      after decode would be {K_i + 1 for i \\in [0,N)}
    * The test you want to do requires (1) having the prefill slot mapping
      for all tokens present during prefill, the number of which is
      M = \\sum_i{K_i}, and (2) having the decode slot mapping for all N
      decoded tokens

    This function consumes a single 1D slot mapping, which is the
    concatenation of N slot mappings each of length K_i + 1 (corresponding
    to the sequence lengths after decode), with a total length of
    P = \\sum_i{K_i + 1} = M + N

    The prefill-phase slot mapping results from excising the (K_i + 1)-th entry
    from each of the N subsequences in the slot mapping (i.e. omitting the
    decoded token's mapping.)

    The N excised entries are appended to obtain the decode-phase slot mapping

    Arguments:

    * slot_mapping_list: Length-P 1D slot mapping (as List) reflecting all N
      post-decode sequences
    * seq_lens: List of N post-decode sequence lengths (K_i + 1 in the
      description above)
    * device: cuda, cpu, etc.

    Returns:

    * prefill_slot_mapping: Length-M 1D slot mapping (as Tensor)
      reflecting all N prefill prompts
    * decode_slot_mapping: Length-N 1D slot mapping (as Tensor) reflecting
      all N decoded tokens
    '''

    prefill_slot_mapping = []
    decode_slot_mapping = []

    base_idx = 0
    for seq_len in seq_lens:
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:(base_idx +
                                                                seq_len - 1)])
        decode_slot_mapping.append(slot_mapping_list[base_idx + seq_len - 1])
        base_idx += seq_len

    return (maybe_make_long_tensor(prefill_slot_mapping, device),
            maybe_make_long_tensor(decode_slot_mapping, device))
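# Illustrative sketch (not executed): for seq_lens=[3, 2] and a flat slot
# mapping [10, 11, 12, 20, 21], split_slot_mapping() returns prefill slots
# [10, 11, 20] (everything except each sequence's last entry) and decode
# slots [12, 21] (the excised last entries).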


def make_block_tables_slot_mapping(
        block_size: int,
        seq_lens: List[int],
        device: Union[torch.device, str],
        block_base_addr: int = 0) -> Tuple[torch.Tensor, List[int], int]:
    '''
    Construct fake block tables & slot mappings.

    For a sequence with num_tokens tokens the minimum number
    of required KV cache blocks is

    num_blocks = (num_tokens + block_size) // block_size

    Then the minimum KV cache size in blocks is

    total_cache_blocks = sum(num_blocks for all seqs)

    Then, the block table mapping counts downward from

    block_base_addr + total_cache_blocks

    to

    block_base_addr


    The constructed block-tables and slot-mapping are sized to the
    lengths of the sequences in their entirety (as reflected by seq_lens),
    i.e. the total of prefill prompt tokens + decoded tokens.

    Arguments:

    * block_size: number of offsets per block
    * seq_lens: list of token-counts for each sequence
    * block_base_addr: the block table base address
    * device: CPU or CUDA device

    Returns:

    * block_tables_tensor: padded block tables for all sequences
    * slot_mapping_list: flat slot mapping covering all sequences
    * max_block_idx: the highest block address within this block table
    '''

    # Provision minimum number of KV cache blocks
    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, block_size)
        for num_tokens in seq_lens
    ]
    max_block_table_len = max(num_blocks_list)
    block_table_pad_tokens = 10

    block_tables = []
    slot_mapping_list = []
    # Compute uppermost address of block table
    total_cache_blocks = sum(num_blocks_list)
    block_base_idx = block_base_addr + total_cache_blocks
    max_block_idx = block_base_idx
    for sdx, num_tokens in enumerate(seq_lens):
        num_blocks = num_blocks_list[sdx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        for idx in range(num_tokens):
            mapping_value = (
                idx % block_size) + block_table[idx // block_size] * block_size
            slot_mapping_list.append(mapping_value)

        block_base_idx -= num_blocks
        block_tables.append(block_table)

    block_tables_tensor = make_tensor_with_pad(
        block_tables,
        max_len=max_block_table_len + block_table_pad_tokens,
        pad=0,
        dtype=torch.int,
        device=device,
    )

    return (block_tables_tensor, slot_mapping_list, max_block_idx)
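# Illustrative sketch (not executed): with block_size=16, seq_lens=[2, 2] and
# block_base_addr=0, each sequence needs one block, so block addresses count
# down from 2: sequence 0 gets block 2 (slots 32, 33) and sequence 1 gets
# block 1 (slots 16, 17); max_block_idx is 2.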


def make_test_metadata(
    attn_backend: _Backend,
    is_prompt: bool,
    seq_lens: Optional[List[int]],
    decoder_test_params: Optional[PhaseTestParameters],
    device: Union[torch.device, str],
    encoder_test_params: Optional[PhaseTestParameters] = None,
    cross_test_params: Optional[PhaseTestParameters] = None
) -> AttentionMetadata:
    '''
    Construct fake attention metadata for a given test phase
    (prefill-phase or decode-phase).

    encoder_test_params and cross_test_params arguments allow encoder
    attention and enc/dec cross-attention (respectively) to use distinct
    metadata values from decoder self-attention (decoder_test_params.)

    If encoder_test_params and cross_test_params are None, the attention
    metadata will support a decoder-only scenario.

    Assumptions:

    * No chunked prefill -> a batch is 100% prefill or 100% decode, never both

    Arguments:

    * attn_backend: Backend for sourcing attention kernels
    * is_prompt: prefill if True, o/w decode
    * seq_lens: list of token counts for each sequence
    * decoder_test_params: decoder self-attention test params;
      this function requires the kv_mmap (memory mapping) field
    * device: CPU or CUDA device
    * encoder_test_params: encoder attention test params;
      this function requires the encoder query sequence lengths field.
      If None, encoder query sequence lengths are treated as None
    * cross_test_params: enc/dec cross-attention test params;
      this function requires the kv_mmap field. If None, KV cache
      memory map data structures are treated as None

    Returns:

    * AttentionMetadata structure
    '''

    # Decoder self-attention memory mapping
    # decoder_test_params is None signals encoder-only
    # scenario, so kv_mmap is None
    kv_mmap = (None
               if decoder_test_params is None else decoder_test_params.kv_mmap)

    # This function constructs metadata assuming no chunked prefill,
    # i.e. 100% prefill tokens or 100% decode tokens
    #
    # - If is_prompt, num_prefills_or_decodes is the number of prefills
    #   and num_prefill_or_decode_tokens is the number of prefill tokens
    # - If not is_prompt, num_prefills_or_decodes is the number of decodes
    #   and num_prefill_or_decode_tokens is the number of decode tokens
    #
    # seq_lens is None signals encoder-only
    # scenario, in which case num_prefills_or_decodes and
    # num_prefill_or_decode_tokens are unused
    num_prefills_or_decodes = (None if seq_lens is None else len(seq_lens))

    num_prefill_or_decode_tokens = (None if seq_lens is None else (
        sum(seq_lens) if is_prompt else len(seq_lens)))

    # For non-prefix-caching scenarios, context_lens
    # appears to be unneeded
    context_lens = None

    if encoder_test_params is None:
        encoder_seq_lens = None
        num_encoder_tokens = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract encoder input sequence lengths
        assert encoder_test_params.packed_qkvo.packed_qkv is not None
        encoder_seq_lens = encoder_test_params.packed_qkvo.packed_qkv.q_seq_lens
        num_encoder_tokens = (None if encoder_seq_lens is None else
                              (sum(encoder_seq_lens)))

    if cross_test_params is None:
        cross_kv_mmap = None
    else:
        # Encoder/decoder or encoder-only models only:
        # * Extract *cross-attention* slot_mapping and block table
        #   (kv_mmap)
        cross_kv_mmap = cross_test_params.kv_mmap

    attn_backend_obj = make_backend(attn_backend.name)

    if is_prompt:
        # Prefill-phase scenario

        num_prefills = num_prefills_or_decodes
        num_prefill_tokens = num_prefill_or_decode_tokens
        num_decode_tokens = 0

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            seq_start_loc,
            encoder_seq_lens_tensor,
            encoder_seq_start_loc,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)
        return attn_backend_obj.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            seq_start_loc=seq_start_loc,
            max_prefill_seq_len=None if seq_lens is None else max(seq_lens),
            max_decode_seq_len=0,
            context_lens_tensor=context_lens_tensor,
            block_tables=(None if kv_mmap is None else kv_mmap.block_tables),
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            encoder_seq_start_loc=encoder_seq_start_loc,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))

    else:  # not is_prompt
        # Decode-phase scenario

        assert kv_mmap is not None
        assert num_prefill_or_decode_tokens is not None
        assert seq_lens is not None

        num_prefills = 0
        num_prefill_tokens = 0
        num_decode_tokens = num_prefill_or_decode_tokens

        (
            seq_lens_tensor,
            context_lens_tensor,
            _,
            _,
            seq_start_loc,
            encoder_seq_lens_tensor,
            encoder_seq_start_loc,
            max_encoder_seq_len,
        ) = _make_metadata_tensors(seq_lens,
                                   context_lens,
                                   encoder_seq_lens,
                                   device=device)

        return attn_backend_obj.make_metadata(
            num_prefills=num_prefills,
            slot_mapping=kv_mmap.slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=True,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            seq_start_loc=seq_start_loc,
            max_prefill_seq_len=0,
            max_decode_seq_len=max(seq_lens),
            max_decode_query_len=1,
            context_lens_tensor=context_lens_tensor,
            block_tables=kv_mmap.block_tables,
            use_cuda_graph=False,
            num_encoder_tokens=num_encoder_tokens,
            encoder_seq_lens=encoder_seq_lens,
            encoder_seq_lens_tensor=encoder_seq_lens_tensor,
            encoder_seq_start_loc=encoder_seq_start_loc,
            max_encoder_seq_len=max_encoder_seq_len,
            cross_slot_mapping=(None if cross_kv_mmap is None else
                                cross_kv_mmap.slot_mapping),
            cross_block_tables=(None if cross_kv_mmap is None else
                                cross_kv_mmap.block_tables))


def assert_actual_matches_ideal(test_params: PhaseTestParameters,
                                output_under_test: torch.Tensor,
                                backend: str) -> None:
    '''
    Assert that observed output matches the ideal output
    contained in the test parameters data structure.

    Arguments:

    * test_params: Test parameters including packed ideal output
    * output_under_test: actually observed output value
    * backend: 'XFORMERS' or 'FLASH_ATTN'; selects the comparison tolerances
    '''
    ideal_output = test_params.packed_qkvo.ideal_output
    if backend == 'XFORMERS':
        torch.testing.assert_close(ideal_output,
                                   output_under_test.view_as(ideal_output))

    elif backend == 'FLASH_ATTN':
        # For FlashAttention, override the accuracy thresholds with
        # non-default values since we observe a larger difference between
        # the ideal and actual output.
        torch.testing.assert_close(ideal_output,
                                   output_under_test.view_as(ideal_output),
                                   atol=0.01,
                                   rtol=0.016)
    else:
        raise ValueError(
            f"Unknown backend value: '{backend}'. Expected 'XFORMERS' or "
            f"'FLASH_ATTN'.")


# Copied/modified from torch._refs.__init__.py
def fp8_allclose(
    a: TensorLikeType,
    b: TensorLikeType,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
) -> bool:
    """
    Reference implementation of torch.allclose, with the inputs upcast to
    double so that fp8 dtypes can be compared.
    """
    torch._refs._check_close_args(name="torch.allclose",
                                  a=a,
                                  b=b,
                                  rtol=rtol,
                                  atol=atol)

    return bool(
        torch.all(
            torch.isclose(a.double(),
                          b.double(),
                          rtol=rtol,
                          atol=atol,
                          equal_nan=equal_nan)).item())


# Marlin MoE test utils


def stack_and_dev(tensors: List[torch.Tensor]):
    dev = tensors[0].device
    return torch.stack(tensors, dim=0).to(dev)


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


def torch_moe(a, w1, w2, score, topk, expert_map):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)
    topk_weight = topk_weight.view(-1)
    topk_ids = topk_ids.view(-1)
    if expert_map is not None:
        topk_ids = expert_map[topk_ids]
    for i in range(w1.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = SiluAndMul()(
                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
    return (out.view(B, -1, w2.shape[1]) *
            topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
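# Illustrative usage sketch (not executed; the sizes below are arbitrary
# assumptions) showing torch_moe as a reference for fused-MoE kernels:
#
#   num_tokens, hidden, inter, num_experts, topk = 16, 128, 256, 8, 2
#   a = torch.randn(num_tokens, hidden)
#   w1 = torch.randn(num_experts, 2 * inter, hidden)
#   w2 = torch.randn(num_experts, hidden, inter)
#   score = torch.randn(num_tokens, num_experts)
#   ref_out = torch_moe(a, w1, w2, score, topk, expert_map=None)
#   # ref_out.shape == (num_tokens, hidden)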


def torch_moe_single(a, w, score, topk):
    B, D = a.shape
    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
    out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device)
    score = torch.softmax(score, dim=-1, dtype=torch.float32)
    _, topk_ids = torch.topk(score, topk)
    topk_ids = topk_ids.view(-1)
    for i in range(w.shape[0]):
        mask = topk_ids == i
        if mask.sum():
            out[mask] = a[mask] @ w[i].transpose(0, 1)
    return (out.view(B, -1, w.shape[1])).sum(dim=1)


# A special version of opcheck that patches torch.allclose with a version
# that supports fp8 types and lets callers restrict the set of test_utils.
def opcheck(op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket,
                      torch._library.custom_ops.CustomOpDef],
            args: Tuple[Any, ...],
            kwargs: Optional[Dict[str, Any]] = None,
            *,
            test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
            raise_exception: bool = True,
            cond: bool = True) -> Dict[str, str]:
    with unittest.mock.patch('torch.allclose', new=fp8_allclose):
        return torch.library.opcheck(
            op,
            args,
            kwargs,
            test_utils=test_utils,
            raise_exception=raise_exception) if cond else {}
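# Illustrative usage sketch (not executed; the op name and tensors are
# hypothetical, not a real registered op):
#
#   opcheck(torch.ops.my_namespace.my_custom_op,
#           (input_tensor, weight_tensor),
#           test_utils=DEFAULT_OPCHECK_TEST_UTILS)
#
# Under the hood this runs torch.library.opcheck with torch.allclose patched
# to fp8_allclose, so ops producing fp8 outputs can still be validated.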


# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor):
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
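# Illustrative sketch (not executed): to_fp8 clamps to the representable
# float8_e4m3fn range before rounding and casting, while to_int8 clamps to
# [-128, 127], e.g.
#
#   to_int8(torch.tensor([300.4, -300.6, 1.6]))
#   # -> tensor([127, -128, 2], dtype=torch.int8)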


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: Type[torch.dtype],
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:

    # We treat N-dimensional group scaling as extended numpy-style
    # broadcasting. Broadcasting in numpy simply stretches dimensions with an
    # extent of 1 to match the target shape by repeating the data along that
    # dimension. We extend these semantics: if the extent of a dimension in
    # the source shape is not 1 and does not match the target shape, we
    # repeat each element along that dimension
    # target_shape[dim] // src_shape[dim] times.
    #
    # Example: given
    #   a = [[1, 2],    and target_shape = (2, 4)
    #        [3, 4]]
    # we would expand a to
    #   a = [[1, 1, 2, 2],
    #        [3, 3, 4, 4]]
    #
    # NOTE: this function does not explicitly broadcast dimensions with an
    # extent of 1, since that can be done implicitly by PyTorch.
    def group_broadcast(t, shape):
        for i, s in enumerate(shape):
            if t.shape[i] != s and t.shape[i] != 1:
                assert s % t.shape[i] == 0
                t = t.unsqueeze(i + 1)\
                    .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
                    .flatten(i, i + 1)
        return t

    scale_a = group_broadcast(scale_a, a.shape)
    scale_b = group_broadcast(scale_b, b.shape)

    output = torch.mm((scale_a * a.to(dtype=torch.float32)),
                      (scale_b * b.to(dtype=torch.float32))).to(out_dtype)

    if bias is not None:
        output = output + bias

    return output
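# Illustrative usage sketch (not executed; shapes are arbitrary assumptions).
# Per-tensor scaling corresponds to (1, 1)-shaped scales; group/blockwise
# scales with smaller extents are expanded by group_broadcast above:
#
#   M, K, N = 16, 256, 64
#   a = to_fp8(torch.randn(M, K))
#   b = to_fp8(torch.randn(K, N))
#   scale_a = torch.rand(M, 1)    # per-row scales for a
#   scale_b = torch.rand(1, N)    # per-column scales for b
#   ref = baseline_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
#   # ref.shape == (M, N)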