[ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm (#3643)
Co-authored-by: jpvillam <jpvillam@amd.com> Co-authored-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
e23a43aef8
commit
6c0b04515f
@ -23,6 +23,9 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
|
|||||||
# In that case, we need to use the python reference attention implementation in vllm
|
# In that case, we need to use the python reference attention implementation in vllm
|
||||||
ARG BUILD_FA="1"
|
ARG BUILD_FA="1"
|
||||||
|
|
||||||
|
# whether to build triton on rocm
|
||||||
|
ARG BUILD_TRITON="1"
|
||||||
|
|
||||||
# Install some basic utilities
|
# Install some basic utilities
|
||||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||||
|
|
||||||
@ -75,6 +78,17 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
|
|||||||
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
|
||||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi
|
||||||
|
|
||||||
|
# build triton
|
||||||
|
RUN if [ "$BUILD_TRITON" = "1" ]; then \
|
||||||
|
mkdir -p libs \
|
||||||
|
&& cd libs \
|
||||||
|
&& pip uninstall -y triton \
|
||||||
|
&& git clone https://github.com/ROCm/triton.git \
|
||||||
|
&& cd triton/python \
|
||||||
|
&& pip3 install . \
|
||||||
|
&& cd ../..; \
|
||||||
|
fi
|
||||||
|
|
||||||
COPY ./ /app/vllm
|
COPY ./ /app/vllm
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip
|
RUN python3 -m pip install --upgrade pip
|
||||||
|
348
vllm/attention/backends/rocm_flash_attn.py
Normal file
348
vllm/attention/backends/rocm_flash_attn.py
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
"""Attention layer ROCm GPUs."""
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionMetadata)
|
||||||
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
|
PagedAttentionMetadata)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ROCmFlashAttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
|
||||||
|
return ROCmFlashAttentionImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def make_metadata(*args, **kwargs) -> "ROCmFlashAttentionMetadata":
|
||||||
|
return ROCmFlashAttentionMetadata(*args, **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
) -> Tuple[int, ...]:
|
||||||
|
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||||
|
num_kv_heads, head_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def swap_blocks(
|
||||||
|
src_kv_cache: torch.Tensor,
|
||||||
|
dst_kv_cache: torch.Tensor,
|
||||||
|
src_to_dst: Dict[int, int],
|
||||||
|
) -> None:
|
||||||
|
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def copy_blocks(
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
src_to_dists: Dict[int, List[int]],
|
||||||
|
) -> None:
|
||||||
|
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
|
"""Metadata for FlashAttentionBackend.
|
||||||
|
|
||||||
|
NOTE: Any python object stored here is not updated when it is
|
||||||
|
cuda-graph replayed. If you have values that need to be changed
|
||||||
|
dynamically, it should be stored in tensor. The tensor has to be
|
||||||
|
updated from `CUDAGraphRunner.forward` API.
|
||||||
|
"""
|
||||||
|
# Currently, input sequences can only contain all prompts
|
||||||
|
# or all decoding. True if all sequences are prompts.
|
||||||
|
is_prompt: bool
|
||||||
|
# (batch_size,). The prompt length per sequence. None if it is a decoding.
|
||||||
|
prompt_lens: Optional[List[int]]
|
||||||
|
# prompt_lens stored as a tensor.
|
||||||
|
prompt_lens_tensor: Optional[torch.Tensor]
|
||||||
|
# The number of prompt tokens. Doesn't include padding.
|
||||||
|
num_prompt_tokens: int
|
||||||
|
# The number of generation tokens. Doesn't include padding.
|
||||||
|
num_generation_tokens: int
|
||||||
|
|
||||||
|
# NOTE(sang): Definition of context_len, subquery_len, and seqlen.
|
||||||
|
# |---------- N-1 iteration --------|
|
||||||
|
# |---------------- N iteration ---------------------|
|
||||||
|
# |- tokenA -|......................|-- newTokens ---|
|
||||||
|
# |---------- context_len ----------|
|
||||||
|
# |-------------------- seqlen ----------------------|
|
||||||
|
# |- subquery_len -|
|
||||||
|
|
||||||
|
# WARNING(sang): context_len has different definition depending on if it is
|
||||||
|
# prefill vs decoding. When it is prefill, it doesn't include new tokens.
|
||||||
|
# When it is for decoding, it includes a new token.
|
||||||
|
|
||||||
|
# Maximum subquery length in the batch.
|
||||||
|
max_subquery_len: Optional[int]
|
||||||
|
# Maximum prompt length in the batch.
|
||||||
|
max_prompt_len: Optional[int]
|
||||||
|
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||||
|
# the batch, used to index into subquery. E.g., if the subquery length
|
||||||
|
# is [4, 6], it is [0, 4, 10].
|
||||||
|
subquery_start_loc: Optional[torch.Tensor]
|
||||||
|
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||||
|
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||||
|
# [4, 6], it is [0, 4, 10].
|
||||||
|
seq_start_loc: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
# Whether or not if cuda graph is enabled.
|
||||||
|
# Cuda-graph is currently enabled for decoding only.
|
||||||
|
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||||
|
use_cuda_graph: bool
|
||||||
|
|
||||||
|
|
||||||
|
class ROCmFlashAttentionImpl(AttentionImpl):
|
||||||
|
"""
|
||||||
|
If the input tensors contain prompt tokens, the layout is as follows:
|
||||||
|
|<--------------- num_prompt_tokens -------------->|
|
||||||
|
|<--prompt_0-->|<--prompt_1-->|...|<--prompt_N-1-->|
|
||||||
|
|
||||||
|
Otherwise, the layout is as follows:
|
||||||
|
|<------------------ num_generation_tokens (M) ----------------->|
|
||||||
|
|<--generation_0-->|..........|<--generation_M-1-->|<--padding-->|
|
||||||
|
|
||||||
|
Generation tokens can contain padding when cuda-graph is used.
|
||||||
|
Currently, prompt tokens don't contain any padding.
|
||||||
|
|
||||||
|
The prompts might have different lengths, while the generation tokens
|
||||||
|
always have length 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: Optional[int] = None,
|
||||||
|
alibi_slopes: Optional[List[float]] = None,
|
||||||
|
sliding_window: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_size = head_size
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
||||||
|
self.sliding_window = ((sliding_window, sliding_window)
|
||||||
|
if sliding_window is not None else (-1, -1))
|
||||||
|
if alibi_slopes is not None:
|
||||||
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
|
self.alibi_slopes = alibi_slopes
|
||||||
|
|
||||||
|
assert self.num_heads % self.num_kv_heads == 0
|
||||||
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
|
||||||
|
suppored_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||||
|
if head_size not in suppored_head_sizes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by PagedAttention. "
|
||||||
|
f"Supported head sizes are: {suppored_head_sizes}.")
|
||||||
|
|
||||||
|
self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
|
||||||
|
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
|
||||||
|
self.use_triton_flash_attn = (os.environ.get(
|
||||||
|
"VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
|
||||||
|
if self.use_naive_attn:
|
||||||
|
# AMD Radeon 7900 series (gfx1100) currently does not support
|
||||||
|
# xFormers nor FlashAttention. As a temporary workaround, we use
|
||||||
|
# naive PyTorch implementation of attention.
|
||||||
|
self.attn_fuc = _naive_attention()
|
||||||
|
logger.debug("Using naive attention in ROCmBackend")
|
||||||
|
elif self.use_triton_flash_attn:
|
||||||
|
from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
|
||||||
|
triton_attention)
|
||||||
|
self.attn_func = triton_attention
|
||||||
|
logger.debug("Using Triton FA in ROCmBackend")
|
||||||
|
else:
|
||||||
|
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||||
|
self.attn_func = flash_attn_varlen_func
|
||||||
|
logger.debug("Using CK FA in ROCmBackend")
|
||||||
|
|
||||||
|
def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||||
|
tokens, n_kv_heads, head_dim = x.shape
|
||||||
|
return (x[:, :,
|
||||||
|
None, :].expand(tokens, n_kv_heads, n_rep,
|
||||||
|
head_dim).reshape(tokens, n_kv_heads * n_rep,
|
||||||
|
head_dim))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: ROCmFlashAttentionMetadata,
|
||||||
|
kv_scale: float = 1.0,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass with FlashAttention and PagedAttention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: shape = [num_tokens, num_heads * head_size]
|
||||||
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
|
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||||
|
attn_metadata: Metadata for attention.
|
||||||
|
Returns:
|
||||||
|
shape = [num_tokens, num_heads * head_size]
|
||||||
|
"""
|
||||||
|
num_tokens, hidden_size = query.shape
|
||||||
|
# Reshape the query, key, and value tensors.
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
|
if kv_cache is not None:
|
||||||
|
key_cache, value_cache = PagedAttention.split_kv_cache(
|
||||||
|
kv_cache, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
|
# Reshape the input keys and values and store them in the cache.
|
||||||
|
# If kv_cache is not provided, the new key and value tensors are
|
||||||
|
# not cached. This happens during the initial memory profiling run.
|
||||||
|
PagedAttention.write_to_paged_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping,
|
||||||
|
attn_metadata.kv_cache_dtype,
|
||||||
|
kv_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_metadata.is_prompt:
|
||||||
|
# Prompt run.
|
||||||
|
if kv_cache is None or attn_metadata.block_tables.numel() == 0:
|
||||||
|
# triton attention
|
||||||
|
# When block_tables are not filled, it means q and k are the
|
||||||
|
# prompt, and they have the same length.
|
||||||
|
if self.use_naive_attn or self.use_triton_flash_attn:
|
||||||
|
if self.num_kv_heads != self.num_heads:
|
||||||
|
# Interleave for MQA workaround.
|
||||||
|
key = self.repeat_kv(key, self.num_queries_per_kv)
|
||||||
|
value = self.repeat_kv(value, self.num_queries_per_kv)
|
||||||
|
if self.use_naive_attn:
|
||||||
|
output = self.attn_fuc(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_metadata.prompt_lens,
|
||||||
|
self.scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output, _ = self.attn_func(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
None,
|
||||||
|
attn_metadata.seq_start_loc,
|
||||||
|
attn_metadata.seq_start_loc,
|
||||||
|
attn_metadata.max_prompt_len,
|
||||||
|
attn_metadata.max_prompt_len,
|
||||||
|
True,
|
||||||
|
self.scale,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = self.attn_func(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
cu_seqlens_q=attn_metadata.seq_start_loc,
|
||||||
|
cu_seqlens_k=attn_metadata.seq_start_loc,
|
||||||
|
max_seqlen_q=attn_metadata.max_prompt_len,
|
||||||
|
max_seqlen_k=attn_metadata.max_prompt_len,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# prefix-enabled attention
|
||||||
|
output = PagedAttention.forward_prefix(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
attn_metadata.subquery_start_loc,
|
||||||
|
attn_metadata.prompt_lens_tensor,
|
||||||
|
attn_metadata.context_lens,
|
||||||
|
attn_metadata.max_subquery_len,
|
||||||
|
self.alibi_slopes,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Decoding run.
|
||||||
|
output = PagedAttention.forward_decode(
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
attn_metadata.context_lens,
|
||||||
|
attn_metadata.max_context_len,
|
||||||
|
attn_metadata.kv_cache_dtype,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.scale,
|
||||||
|
self.alibi_slopes,
|
||||||
|
kv_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape the output tensor.
|
||||||
|
return output.view(num_tokens, hidden_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _naive_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
prompt_lens: List[int],
|
||||||
|
scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
num_tokens = query.shape[0]
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
start = 0
|
||||||
|
for _, prompt_len in enumerate(prompt_lens):
|
||||||
|
end = start + prompt_len
|
||||||
|
out = _naive_masked_attention(
|
||||||
|
query[None, start:end],
|
||||||
|
key[None, start:end],
|
||||||
|
value[None, start:end],
|
||||||
|
scale,
|
||||||
|
)
|
||||||
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
|
output[start:end].copy_(out)
|
||||||
|
start += prompt_len
|
||||||
|
|
||||||
|
# Using view got RuntimeError: view size is not compatible
|
||||||
|
# with input tensor's size and stride (at least one
|
||||||
|
# dimension spans across two contiguous subspaces).
|
||||||
|
# Use reshape instead.
|
||||||
|
return output.reshape(num_tokens, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def _naive_masked_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
scale: float,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
seq_len, _, _ = query.shape
|
||||||
|
attn_mask = torch.triu(torch.ones(seq_len,
|
||||||
|
seq_len,
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device),
|
||||||
|
diagonal=1)
|
||||||
|
attn_mask = attn_mask * torch.finfo(query.dtype).min
|
||||||
|
|
||||||
|
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
||||||
|
attn_weights = attn_weights + attn_mask.float()
|
||||||
|
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
||||||
|
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
||||||
|
return out
|
@ -1,5 +1,4 @@
|
|||||||
"""Attention layer with xFormers and PagedAttention."""
|
"""Attention layer with xFormers and PagedAttention."""
|
||||||
import importlib
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
@ -14,7 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
PagedAttentionMetadata)
|
PagedAttentionMetadata)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import is_hip
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -166,11 +164,6 @@ class XFormersImpl(AttentionImpl):
|
|||||||
f"Head size {head_size} is not supported by PagedAttention. "
|
f"Head size {head_size} is not supported by PagedAttention. "
|
||||||
f"Supported head sizes are: {suppored_head_sizes}.")
|
f"Supported head sizes are: {suppored_head_sizes}.")
|
||||||
|
|
||||||
# AMD Radeon 7900 series (gfx1100) currently does not support xFormers
|
|
||||||
# nor FlashAttention. As a temporary workaround, we use naive PyTorch
|
|
||||||
# implementation of attention.
|
|
||||||
self.use_naive_attention = _check_use_naive_attention()
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@ -233,30 +226,6 @@ class XFormersImpl(AttentionImpl):
|
|||||||
self.num_queries_per_kv,
|
self.num_queries_per_kv,
|
||||||
value.shape[-1])
|
value.shape[-1])
|
||||||
|
|
||||||
if self.use_naive_attention:
|
|
||||||
output = torch.empty_like(query)
|
|
||||||
start = 0
|
|
||||||
for _, prompt_len in enumerate(attn_metadata.prompt_lens):
|
|
||||||
end = start + prompt_len
|
|
||||||
out = _naive_masked_attention(
|
|
||||||
query[None, start:end],
|
|
||||||
key[None, start:end],
|
|
||||||
value[None, start:end],
|
|
||||||
self.num_heads,
|
|
||||||
self.num_kv_heads,
|
|
||||||
self.head_size,
|
|
||||||
self.scale,
|
|
||||||
)
|
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
|
||||||
output[start:end].copy_(out)
|
|
||||||
start += prompt_len
|
|
||||||
|
|
||||||
# Using view got RuntimeError: view size is not compatible
|
|
||||||
# with input tensor's size and stride (at least one
|
|
||||||
# dimension spans across two contiguous subspaces).
|
|
||||||
# Use reshape instead.
|
|
||||||
return output.reshape(num_tokens, hidden_size)
|
|
||||||
|
|
||||||
output = self._run_memory_efficient_xformers_forward(
|
output = self._run_memory_efficient_xformers_forward(
|
||||||
query, key, value, attn_metadata)
|
query, key, value, attn_metadata)
|
||||||
else:
|
else:
|
||||||
@ -329,8 +298,6 @@ class XFormersImpl(AttentionImpl):
|
|||||||
self.alibi_slopes, self.num_kv_heads, query.dtype,
|
self.alibi_slopes, self.num_kv_heads, query.dtype,
|
||||||
attn_metadata.prompt_lens)
|
attn_metadata.prompt_lens)
|
||||||
|
|
||||||
op = xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if (
|
|
||||||
is_hip()) else None
|
|
||||||
# No alibi slopes.
|
# No alibi slopes.
|
||||||
# TODO(woosuk): Too many view operations. Let's try to reduce
|
# TODO(woosuk): Too many view operations. Let's try to reduce
|
||||||
# them in the future for code readability.
|
# them in the future for code readability.
|
||||||
@ -344,8 +311,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
value,
|
value,
|
||||||
attn_bias=attn_metadata.attn_bias[0],
|
attn_bias=attn_metadata.attn_bias[0],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale)
|
||||||
op=op)
|
|
||||||
|
|
||||||
return out.view_as(query)
|
return out.view_as(query)
|
||||||
|
|
||||||
@ -363,8 +329,7 @@ class XFormersImpl(AttentionImpl):
|
|||||||
value[None, start:end],
|
value[None, start:end],
|
||||||
attn_bias=attn_metadata.attn_bias[i],
|
attn_bias=attn_metadata.attn_bias[i],
|
||||||
p=0.0,
|
p=0.0,
|
||||||
scale=self.scale,
|
scale=self.scale)
|
||||||
op=op)
|
|
||||||
# TODO(woosuk): Unnecessary copy. Optimize.
|
# TODO(woosuk): Unnecessary copy. Optimize.
|
||||||
output[start:end].copy_(out.squeeze(0))
|
output[start:end].copy_(out.squeeze(0))
|
||||||
start += prompt_len
|
start += prompt_len
|
||||||
@ -405,42 +370,3 @@ def _make_alibi_bias(
|
|||||||
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
|
attn_biases.append(LowerTriangularMaskWithTensorBias(bias))
|
||||||
|
|
||||||
return attn_biases
|
return attn_biases
|
||||||
|
|
||||||
|
|
||||||
def _check_use_naive_attention() -> bool:
|
|
||||||
if not is_hip():
|
|
||||||
return False
|
|
||||||
# For ROCm, check whether flash attention is installed or not.
|
|
||||||
use_naive_attention = importlib.util.find_spec("flash_attn") is None
|
|
||||||
if use_naive_attention:
|
|
||||||
logger.warning("flash_attn is not installed. Using naive attention. "
|
|
||||||
"This will take significantly more GPU memory.")
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _naive_masked_attention(
|
|
||||||
query: torch.Tensor,
|
|
||||||
key: torch.Tensor,
|
|
||||||
value: torch.Tensor,
|
|
||||||
num_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
head_size: int,
|
|
||||||
scale: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
query = query.view(-1, num_heads, head_size)
|
|
||||||
key = key.view(-1, num_kv_heads, head_size)
|
|
||||||
value = value.view(-1, num_kv_heads, head_size)
|
|
||||||
seq_len, _, _ = query.shape
|
|
||||||
attn_mask = torch.triu(torch.ones(seq_len,
|
|
||||||
seq_len,
|
|
||||||
dtype=query.dtype,
|
|
||||||
device=query.device),
|
|
||||||
diagonal=1)
|
|
||||||
attn_mask = attn_mask * torch.finfo(query.dtype).min
|
|
||||||
|
|
||||||
attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float()
|
|
||||||
attn_weights = attn_weights + attn_mask.float()
|
|
||||||
attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
|
|
||||||
out = torch.einsum("hqk,khd->qhd", attn_weights, value)
|
|
||||||
return out
|
|
||||||
|
809
vllm/attention/ops/triton_flash_attention.py
Normal file
809
vllm/attention/ops/triton_flash_attention.py
Normal file
@ -0,0 +1,809 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Fused Attention
|
||||||
|
===============
|
||||||
|
|
||||||
|
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao
|
||||||
|
(https://tridao.me/publications/flash2/flash2.pdf)
|
||||||
|
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
|
||||||
|
|
||||||
|
Features supported:
|
||||||
|
|
||||||
|
1) Fwd with causal masking
|
||||||
|
2) Any sequence lengths without padding (currently fwd kernel only)
|
||||||
|
3) Support for different sequence lengths for q and k
|
||||||
|
4) Nested tensor API currently does not support dropout or bias.
|
||||||
|
|
||||||
|
Not currently supported:
|
||||||
|
|
||||||
|
1) Non power of two head dims
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
|
||||||
|
torch_dtype: tl.constexpr = torch.float16
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def cdiv_fn(x, y):
|
||||||
|
return (x + y - 1) // y
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def max_fn(x, y):
|
||||||
|
return tl.math.max(x, y)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||||
|
ms = tl.arange(0, m)
|
||||||
|
ns = tl.arange(0, n)
|
||||||
|
return philox_offset + ms[:, None] * stride + ns[None, :]
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||||
|
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n,
|
||||||
|
stride).to(tl.uint32)
|
||||||
|
# TODO: use tl.randint for better performance
|
||||||
|
return tl.rand(philox_seed, rng_offsets)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
|
||||||
|
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n,
|
||||||
|
stride)
|
||||||
|
rng_keep = rng_output > dropout_p
|
||||||
|
return rng_keep
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def load_fn(block_ptr, first, second, pad):
|
||||||
|
if first and second:
|
||||||
|
tensor = tl.load(block_ptr, boundary_check=(0, 1), padding_option=pad)
|
||||||
|
elif first:
|
||||||
|
tensor = tl.load(block_ptr, boundary_check=(0, ), padding_option=pad)
|
||||||
|
elif second:
|
||||||
|
tensor = tl.load(block_ptr, boundary_check=(1, ), padding_option=pad)
|
||||||
|
else:
|
||||||
|
tensor = tl.load(block_ptr)
|
||||||
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _attn_fwd_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
K_block_ptr,
|
||||||
|
V_block_ptr,
|
||||||
|
start_m,
|
||||||
|
actual_seqlen_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
batch_philox_offset,
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
block_min,
|
||||||
|
block_max,
|
||||||
|
offs_n_causal,
|
||||||
|
masked_blocks,
|
||||||
|
n_extra_tokens,
|
||||||
|
bias_ptr,
|
||||||
|
IS_CAUSAL: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
OFFS_M: tl.constexpr,
|
||||||
|
OFFS_N: tl.constexpr,
|
||||||
|
PRE_LOAD_V: tl.constexpr,
|
||||||
|
MASK_STEPS: tl.constexpr,
|
||||||
|
ENABLE_DROPOUT: tl.constexpr,
|
||||||
|
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||||
|
PADDED_HEAD: tl.constexpr,
|
||||||
|
):
|
||||||
|
# loop over k, v, and update accumulator
|
||||||
|
for start_n in range(block_min, block_max, BLOCK_N):
|
||||||
|
# For padded blocks, we will overrun the tensor size if
|
||||||
|
# we load all BLOCK_N. For others, the blocks are all within range.
|
||||||
|
k = load_fn(
|
||||||
|
K_block_ptr,
|
||||||
|
PADDED_HEAD,
|
||||||
|
MASK_STEPS and (n_extra_tokens != 0),
|
||||||
|
"zero",
|
||||||
|
)
|
||||||
|
if PRE_LOAD_V:
|
||||||
|
v = load_fn(
|
||||||
|
V_block_ptr,
|
||||||
|
MASK_STEPS and (n_extra_tokens != 0),
|
||||||
|
PADDED_HEAD,
|
||||||
|
"zero",
|
||||||
|
)
|
||||||
|
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||||
|
# We start from end of seqlen_k so only the first iteration would need
|
||||||
|
# to be checked for padding if it is not a multiple of block_n
|
||||||
|
# TODO: This can be optimized to only be true for the padded block.
|
||||||
|
if MASK_STEPS: # noqa: SIM102
|
||||||
|
# If this is the last block / iteration, we want to
|
||||||
|
# mask if the sequence length is not a multiple of block size
|
||||||
|
# a solution is to always do BLOCK_M // BLOCK_N + 1 steps
|
||||||
|
# if not is_modulo_mn. last step might get wasted but that is okay.
|
||||||
|
# check if this masking works for that case.
|
||||||
|
if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0):
|
||||||
|
boundary_m = tl.full([BLOCK_M],
|
||||||
|
actual_seqlen_k,
|
||||||
|
dtype=tl.int32)
|
||||||
|
size_n = start_n + OFFS_N[None, :]
|
||||||
|
mask = size_n < boundary_m[:, None]
|
||||||
|
qk = tl.where(mask, qk, float("-inf"))
|
||||||
|
if IS_CAUSAL:
|
||||||
|
causal_boundary = start_n + offs_n_causal
|
||||||
|
causal_mask = OFFS_M[:, None] >= causal_boundary[None, :]
|
||||||
|
qk = tl.where(causal_mask, qk, float("-inf"))
|
||||||
|
# -- compute qk ----
|
||||||
|
qk += tl.dot(q, k)
|
||||||
|
if bias_ptr is not None:
|
||||||
|
bias = load_fn(bias_ptr, False, MASK_STEPS
|
||||||
|
and (n_extra_tokens != 0), "zero")
|
||||||
|
# While bias is added after multiplying qk with sm_scale, our
|
||||||
|
# optimization to use 2^x instead of e^x results in an additional
|
||||||
|
# scale factor of log2(e) which we must also multiply the bias with.
|
||||||
|
qk += bias * 1.44269504089
|
||||||
|
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||||
|
qk = qk - m_ij[:, None]
|
||||||
|
p = tl.math.exp2(qk)
|
||||||
|
|
||||||
|
# CAVEAT: Must update l_ij before applying dropout
|
||||||
|
l_ij = tl.sum(p, 1)
|
||||||
|
if ENABLE_DROPOUT:
|
||||||
|
philox_offset = (batch_philox_offset +
|
||||||
|
start_m * BLOCK_M * actual_seqlen_k + start_n -
|
||||||
|
BLOCK_N)
|
||||||
|
keep = dropout_mask(
|
||||||
|
philox_seed,
|
||||||
|
philox_offset,
|
||||||
|
dropout_p,
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_N,
|
||||||
|
actual_seqlen_k,
|
||||||
|
)
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
tl.store(
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
tl.where(keep, p,
|
||||||
|
-p).to(encoded_softmax_block_ptr.type.element_ty),
|
||||||
|
)
|
||||||
|
p = tl.where(keep, p, 0.0)
|
||||||
|
elif RETURN_ENCODED_SOFTMAX:
|
||||||
|
tl.store(
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
p.to(encoded_softmax_block_ptr.type.element_ty),
|
||||||
|
)
|
||||||
|
# -- update output accumulator --
|
||||||
|
alpha = tl.math.exp2(m_i - m_ij)
|
||||||
|
acc = acc * alpha[:, None]
|
||||||
|
if not PRE_LOAD_V:
|
||||||
|
v = load_fn(
|
||||||
|
V_block_ptr,
|
||||||
|
MASK_STEPS and (n_extra_tokens != 0),
|
||||||
|
PADDED_HEAD,
|
||||||
|
"zero",
|
||||||
|
)
|
||||||
|
# -- update m_i and l_i
|
||||||
|
l_i = l_i * alpha + l_ij
|
||||||
|
# update m_i and l_i
|
||||||
|
m_i = m_ij
|
||||||
|
acc += tl.dot(p.to(V_block_ptr.type.element_ty), v)
|
||||||
|
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||||
|
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||||
|
if bias_ptr is not None:
|
||||||
|
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||||
|
(0, BLOCK_N))
|
||||||
|
return acc, l_i, m_i
|
||||||
|
|
||||||
|
|
||||||
|
@triton.autotune(
|
||||||
|
configs=[
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 256,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 2,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 128,
|
||||||
|
"BLOCK_N": 128,
|
||||||
|
"waves_per_eu": 2,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 256,
|
||||||
|
"BLOCK_N": 128,
|
||||||
|
"waves_per_eu": 2,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 128,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 3,
|
||||||
|
"PRE_LOAD_V": True,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 128,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 3,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 64,
|
||||||
|
"BLOCK_N": 64,
|
||||||
|
"waves_per_eu": 4,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 32,
|
||||||
|
"BLOCK_N": 32,
|
||||||
|
"waves_per_eu": 4,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=8,
|
||||||
|
),
|
||||||
|
# TODO: This config fails with head_size not pow2 with data mismatches.
|
||||||
|
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1,
|
||||||
|
# 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
|
||||||
|
triton.Config(
|
||||||
|
{
|
||||||
|
"BLOCK_M": 16,
|
||||||
|
"BLOCK_N": 16,
|
||||||
|
"waves_per_eu": 1,
|
||||||
|
"PRE_LOAD_V": False,
|
||||||
|
},
|
||||||
|
num_stages=1,
|
||||||
|
num_warps=4,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"],
|
||||||
|
)
|
||||||
|
@triton.jit
|
||||||
|
def attn_fwd(
|
||||||
|
Q,
|
||||||
|
K,
|
||||||
|
V,
|
||||||
|
bias,
|
||||||
|
sm_scale,
|
||||||
|
L,
|
||||||
|
Out,
|
||||||
|
stride_qz,
|
||||||
|
stride_qh,
|
||||||
|
stride_qm,
|
||||||
|
stride_qk,
|
||||||
|
stride_kz,
|
||||||
|
stride_kh,
|
||||||
|
stride_kn,
|
||||||
|
stride_kk,
|
||||||
|
stride_vz,
|
||||||
|
stride_vh,
|
||||||
|
stride_vk,
|
||||||
|
stride_vn,
|
||||||
|
stride_oz,
|
||||||
|
stride_oh,
|
||||||
|
stride_om,
|
||||||
|
stride_on,
|
||||||
|
stride_bz,
|
||||||
|
stride_bh,
|
||||||
|
stride_bm,
|
||||||
|
stride_bn,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
philox_offset_base,
|
||||||
|
encoded_softmax,
|
||||||
|
hq,
|
||||||
|
hk,
|
||||||
|
ACTUAL_BLOCK_DMODEL: tl.constexpr,
|
||||||
|
MAX_SEQLENS_Q: tl.constexpr,
|
||||||
|
MAX_SEQLENS_K: tl.constexpr,
|
||||||
|
VARLEN: tl.constexpr,
|
||||||
|
IS_CAUSAL: tl.constexpr,
|
||||||
|
BLOCK_M: tl.constexpr,
|
||||||
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
|
BLOCK_N: tl.constexpr,
|
||||||
|
PRE_LOAD_V: tl.constexpr,
|
||||||
|
BIAS_TYPE: tl.constexpr,
|
||||||
|
ENABLE_DROPOUT: tl.constexpr,
|
||||||
|
RETURN_ENCODED_SOFTMAX: tl.constexpr,
|
||||||
|
):
|
||||||
|
start_m = tl.program_id(0)
|
||||||
|
off_h_q = tl.program_id(1)
|
||||||
|
off_z = tl.program_id(2)
|
||||||
|
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||||
|
offs_n = tl.arange(0, BLOCK_N)
|
||||||
|
if VARLEN:
|
||||||
|
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
|
||||||
|
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
|
||||||
|
seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start
|
||||||
|
# We have a one-size-fits-all grid in id(0). Some seqlens might be too
|
||||||
|
# small for all start_m so for those we return early.
|
||||||
|
if start_m * BLOCK_M > seqlen_q:
|
||||||
|
return
|
||||||
|
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
|
||||||
|
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
|
||||||
|
seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start
|
||||||
|
else:
|
||||||
|
cu_seqlens_q_start = 0
|
||||||
|
cu_seqlens_k_start = 0
|
||||||
|
seqlen_q = MAX_SEQLENS_Q
|
||||||
|
seqlen_k = MAX_SEQLENS_K
|
||||||
|
|
||||||
|
# Now we compute whether we need to exit early due to causal masking.
|
||||||
|
# This is because for seqlen_q > seqlen_k, M rows of the attn scores
|
||||||
|
# are completely masked, resulting in 0s written to the output, and
|
||||||
|
# inf written to LSE. We don't need to do any GEMMs in this case.
|
||||||
|
# This block of code determines what N is, and if this WG is operating
|
||||||
|
# on those M rows.
|
||||||
|
n_blocks = cdiv_fn(seqlen_k, BLOCK_N)
|
||||||
|
if IS_CAUSAL:
|
||||||
|
# If seqlen_q == seqlen_k, the attn scores are a square matrix.
|
||||||
|
# If seqlen_q != seqlen_k, attn scores are rectangular which means
|
||||||
|
# the causal mask boundary is bottom right aligned, and ends at either
|
||||||
|
# the top edge (seqlen_q < seqlen_k) or left edge.
|
||||||
|
# This captures the decrease in n_blocks if we have a rectangular attn
|
||||||
|
# matrix
|
||||||
|
n_blocks_seqlen = cdiv_fn(
|
||||||
|
(start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N)
|
||||||
|
# This is what adjusts the block_max for the current WG, only
|
||||||
|
# if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks
|
||||||
|
n_blocks = min(n_blocks, n_blocks_seqlen)
|
||||||
|
# If we have no blocks after adjusting for seqlen deltas, this WG is
|
||||||
|
# part of the blocks that are all 0. We exit early.
|
||||||
|
if n_blocks <= 0:
|
||||||
|
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||||
|
off_h_q * stride_oh)
|
||||||
|
O_block_ptr = tl.make_block_ptr(
|
||||||
|
base=Out + o_offset,
|
||||||
|
shape=(seqlen_q, BLOCK_DMODEL),
|
||||||
|
strides=(stride_om, stride_on),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
|
||||||
|
# We still need to write 0s to the result
|
||||||
|
# tl.store(O_block_ptr,
|
||||||
|
# acc.to(Out.type.element_ty), boundary_check=(0,1))
|
||||||
|
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
|
||||||
|
# + offs_m
|
||||||
|
# We store inf to LSE, not -inf because in the bwd pass,
|
||||||
|
# we subtract this
|
||||||
|
# from qk which makes it -inf, such that exp(qk - inf) = 0
|
||||||
|
# for these masked blocks.
|
||||||
|
# l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
|
||||||
|
# tl.store(l_ptrs, l)
|
||||||
|
# TODO: Should dropout and return encoded softmax be handled here?
|
||||||
|
return
|
||||||
|
|
||||||
|
is_mqa = hq != hk
|
||||||
|
off_h_k = off_h_q % hk if is_mqa else off_h_q
|
||||||
|
n_extra_tokens = 0
|
||||||
|
if seqlen_k < BLOCK_N:
|
||||||
|
n_extra_tokens = BLOCK_N - seqlen_k
|
||||||
|
elif seqlen_k % BLOCK_N:
|
||||||
|
n_extra_tokens = seqlen_k % BLOCK_N
|
||||||
|
padded_head = ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL
|
||||||
|
|
||||||
|
# Compute pointers for all the tensors used in this kernel.
|
||||||
|
q_offset = (off_z * stride_qz + off_h_q * stride_qh +
|
||||||
|
cu_seqlens_q_start * stride_qm)
|
||||||
|
Q_block_ptr = tl.make_block_ptr(
|
||||||
|
base=Q + q_offset,
|
||||||
|
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||||
|
strides=(stride_qm, stride_qk),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
k_offset = (off_z * stride_kz + off_h_k * stride_kh +
|
||||||
|
cu_seqlens_k_start * stride_kn)
|
||||||
|
K_block_ptr = tl.make_block_ptr(
|
||||||
|
base=K + k_offset,
|
||||||
|
shape=(ACTUAL_BLOCK_DMODEL, seqlen_k),
|
||||||
|
strides=(stride_kk, stride_kn),
|
||||||
|
offsets=(0, 0),
|
||||||
|
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||||
|
order=(0, 1),
|
||||||
|
)
|
||||||
|
v_offset = (off_z * stride_vz + off_h_k * stride_vh +
|
||||||
|
cu_seqlens_k_start * stride_vk)
|
||||||
|
V_block_ptr = tl.make_block_ptr(
|
||||||
|
base=V + v_offset,
|
||||||
|
shape=(seqlen_k, ACTUAL_BLOCK_DMODEL),
|
||||||
|
strides=(stride_vk, stride_vn),
|
||||||
|
offsets=(0, 0),
|
||||||
|
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
if BIAS_TYPE != 0:
|
||||||
|
bias_ptr = tl.make_block_ptr(
|
||||||
|
base=bias + off_h_q * stride_bh,
|
||||||
|
shape=(seqlen_q, seqlen_k),
|
||||||
|
strides=(stride_bm, stride_bn),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_N),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
bias_ptr = None
|
||||||
|
if ENABLE_DROPOUT:
|
||||||
|
batch_philox_offset = philox_offset_base \
|
||||||
|
+ (off_z * hq + off_h_q) \
|
||||||
|
* seqlen_q * seqlen_k
|
||||||
|
else:
|
||||||
|
batch_philox_offset = 0
|
||||||
|
# We can ask to return the dropout mask without actually doing any dropout.
|
||||||
|
# In this case, we return an invalid pointer so indicate the mask is not i
|
||||||
|
# valid.
|
||||||
|
# TODO: Fix encoded softmax. It currently uses just h_q in the base offset.
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
encoded_softmax_block_ptr = tl.make_block_ptr(
|
||||||
|
base=encoded_softmax + off_h_q * seqlen_q * seqlen_k,
|
||||||
|
shape=(seqlen_q, seqlen_k),
|
||||||
|
strides=(seqlen_k, 1),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_N),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
encoded_softmax_block_ptr = 0
|
||||||
|
# initialize pointer to m and l
|
||||||
|
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
|
||||||
|
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
|
||||||
|
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||||
|
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
|
||||||
|
# have native e^x support in HW.
|
||||||
|
qk_scale = sm_scale * 1.44269504089
|
||||||
|
# Q is loaded once at the beginning and shared by all N blocks.
|
||||||
|
q = load_fn(Q_block_ptr, True, padded_head, "zero")
|
||||||
|
q = (q * qk_scale).to(Q_block_ptr.type.element_ty)
|
||||||
|
|
||||||
|
# Here we compute how many full and masked blocks we have.
|
||||||
|
padded_block_k = n_extra_tokens != 0
|
||||||
|
is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0)
|
||||||
|
if IS_CAUSAL:
|
||||||
|
# There are always at least BLOCK_M // BLOCK_N masked blocks.
|
||||||
|
# Additionally there might be one more due to dissimilar seqlens.
|
||||||
|
masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn)
|
||||||
|
else:
|
||||||
|
# Padding on Q does not need to be masked in the FA loop.
|
||||||
|
masked_blocks = padded_block_k
|
||||||
|
# if IS_CAUSAL, not is_modulo_mn does not always result in an additional
|
||||||
|
# block. In this case we might exceed n_blocks so pick the min.
|
||||||
|
masked_blocks = min(masked_blocks, n_blocks)
|
||||||
|
n_full_blocks = n_blocks - masked_blocks
|
||||||
|
block_min = 0
|
||||||
|
block_max = n_blocks * BLOCK_N
|
||||||
|
# Compute for full blocks. Here we set causal to false regardless of its
|
||||||
|
# value because there is no masking. Similarly we do not need padding.
|
||||||
|
if n_full_blocks > 0:
|
||||||
|
block_max = (n_blocks - masked_blocks) * BLOCK_N
|
||||||
|
acc, l_i, m_i = _attn_fwd_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
K_block_ptr,
|
||||||
|
V_block_ptr,
|
||||||
|
start_m,
|
||||||
|
seqlen_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
batch_philox_offset,
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
|
||||||
|
block_min,
|
||||||
|
block_max,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
bias_ptr,
|
||||||
|
# IS_CAUSAL, ....
|
||||||
|
False,
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_DMODEL,
|
||||||
|
BLOCK_N,
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
# _, MASK_STEPS, ...
|
||||||
|
PRE_LOAD_V,
|
||||||
|
False,
|
||||||
|
ENABLE_DROPOUT,
|
||||||
|
RETURN_ENCODED_SOFTMAX,
|
||||||
|
padded_head,
|
||||||
|
)
|
||||||
|
block_min = block_max
|
||||||
|
block_max = n_blocks * BLOCK_N
|
||||||
|
|
||||||
|
tl.debug_barrier()
|
||||||
|
# Remaining blocks, if any, are full / not masked.
|
||||||
|
if masked_blocks > 0:
|
||||||
|
offs_n_causal = offs_n + (seqlen_q - seqlen_k) if IS_CAUSAL else 0
|
||||||
|
K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks * BLOCK_N))
|
||||||
|
V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks * BLOCK_N, 0))
|
||||||
|
if bias_ptr is not None:
|
||||||
|
bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks * BLOCK_N))
|
||||||
|
if RETURN_ENCODED_SOFTMAX:
|
||||||
|
encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr,
|
||||||
|
(0, n_full_blocks))
|
||||||
|
acc, l_i, m_i = _attn_fwd_inner(
|
||||||
|
acc,
|
||||||
|
l_i,
|
||||||
|
m_i,
|
||||||
|
q,
|
||||||
|
K_block_ptr,
|
||||||
|
V_block_ptr,
|
||||||
|
start_m,
|
||||||
|
seqlen_k,
|
||||||
|
dropout_p,
|
||||||
|
philox_seed,
|
||||||
|
batch_philox_offset,
|
||||||
|
encoded_softmax_block_ptr,
|
||||||
|
block_min,
|
||||||
|
block_max,
|
||||||
|
offs_n_causal,
|
||||||
|
masked_blocks,
|
||||||
|
n_extra_tokens,
|
||||||
|
bias_ptr,
|
||||||
|
IS_CAUSAL,
|
||||||
|
BLOCK_M,
|
||||||
|
BLOCK_DMODEL,
|
||||||
|
BLOCK_N,
|
||||||
|
offs_m,
|
||||||
|
offs_n,
|
||||||
|
# _, MASK_STEPS, ...
|
||||||
|
PRE_LOAD_V,
|
||||||
|
True,
|
||||||
|
ENABLE_DROPOUT,
|
||||||
|
RETURN_ENCODED_SOFTMAX,
|
||||||
|
padded_head,
|
||||||
|
)
|
||||||
|
# epilogue
|
||||||
|
acc = acc / l_i[:, None]
|
||||||
|
if ENABLE_DROPOUT:
|
||||||
|
acc = acc / (1 - dropout_p)
|
||||||
|
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
|
||||||
|
# then we have one block with a row of all NaNs which come from computing
|
||||||
|
# softmax over a row of all -infs (-inf - inf = NaN). We check for that here
|
||||||
|
# and store 0s where there are NaNs as these rows should've been zeroed out.
|
||||||
|
end_m_idx = (start_m + 1) * BLOCK_M
|
||||||
|
start_m_idx = start_m * BLOCK_M
|
||||||
|
causal_start_idx = seqlen_q - seqlen_k
|
||||||
|
acc = acc.to(Out.type.element_ty)
|
||||||
|
if IS_CAUSAL: # noqa: SIM102
|
||||||
|
if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx:
|
||||||
|
out_mask_boundary = tl.full((BLOCK_DMODEL, ),
|
||||||
|
causal_start_idx,
|
||||||
|
dtype=tl.int32)
|
||||||
|
mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M)
|
||||||
|
out_ptrs_mask = (mask_m_offsets[:, None] >=
|
||||||
|
out_mask_boundary[None, :])
|
||||||
|
z = 0.0
|
||||||
|
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
|
||||||
|
# write back LSE
|
||||||
|
# l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
|
||||||
|
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
|
||||||
|
# few rows. This is only true for the last M block. For others,
|
||||||
|
# overflow_size will be -ve
|
||||||
|
# overflow_size = end_m_idx - seqlen_q
|
||||||
|
# if overflow_size > 0:
|
||||||
|
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
|
||||||
|
# # This is a > check because mask being 0 blocks the store.
|
||||||
|
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
|
||||||
|
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
|
||||||
|
# else:
|
||||||
|
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||||
|
|
||||||
|
# write back O
|
||||||
|
o_offset = (off_z * stride_oz + cu_seqlens_q_start * stride_om +
|
||||||
|
off_h_q * stride_oh)
|
||||||
|
O_block_ptr = tl.make_block_ptr(
|
||||||
|
base=Out + o_offset,
|
||||||
|
shape=(seqlen_q, ACTUAL_BLOCK_DMODEL),
|
||||||
|
strides=(stride_om, stride_on),
|
||||||
|
offsets=(start_m * BLOCK_M, 0),
|
||||||
|
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||||
|
order=(1, 0),
|
||||||
|
)
|
||||||
|
# Need boundary check on this to make sure the padding from the
|
||||||
|
# Q and KV tensors in both dims are not part of what we store back.
|
||||||
|
# TODO: Do the boundary check optionally.
|
||||||
|
tl.store(O_block_ptr, acc, boundary_check=(0, 1))
|
||||||
|
|
||||||
|
|
||||||
|
def check_args(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
o,
|
||||||
|
varlen=True,
|
||||||
|
max_seqlens=None,
|
||||||
|
cu_seqlens_q=None,
|
||||||
|
cu_seqlens_k=None,
|
||||||
|
):
|
||||||
|
assert q.dim() == k.dim() and q.dim() == v.dim()
|
||||||
|
if varlen:
|
||||||
|
assert q.dim() == 3
|
||||||
|
total_q, nheads_q, head_size = q.shape
|
||||||
|
total_k, nheads_k, _ = k.shape
|
||||||
|
assert cu_seqlens_q is not None
|
||||||
|
assert cu_seqlens_k is not None
|
||||||
|
assert len(cu_seqlens_q) == len(cu_seqlens_k)
|
||||||
|
else:
|
||||||
|
assert q.dim() == 4
|
||||||
|
batch, nheads_q, seqlen_q, head_size = q.shape
|
||||||
|
_, nheads_k, seqlen_k, _ = k.shape
|
||||||
|
assert max_seqlens > 0
|
||||||
|
assert k.shape == v.shape
|
||||||
|
assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1]
|
||||||
|
# TODO: Change assert if we support qkl f8 and v f16
|
||||||
|
assert q.dtype == k.dtype and q.dtype == v.dtype
|
||||||
|
# TODO: Fix assert to check head size <=256 once supported
|
||||||
|
assert head_size <= 128
|
||||||
|
assert o.shape == q.shape
|
||||||
|
assert (nheads_q % nheads_k) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class _attention(torch.autograd.Function):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
o,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlens_q,
|
||||||
|
max_seqlens_k,
|
||||||
|
causal=False,
|
||||||
|
sm_scale=1.0,
|
||||||
|
bias=None,
|
||||||
|
):
|
||||||
|
if o is None:
|
||||||
|
o = torch.empty_like(q, dtype=v.dtype)
|
||||||
|
|
||||||
|
check_args(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
o,
|
||||||
|
varlen=True,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
)
|
||||||
|
if True: # varlen
|
||||||
|
total_q, nheads_q, head_size = q.shape
|
||||||
|
total_k, nheads_k, _ = k.shape
|
||||||
|
batch = len(cu_seqlens_q) - 1
|
||||||
|
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
|
||||||
|
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
|
||||||
|
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
|
||||||
|
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
|
||||||
|
else:
|
||||||
|
batch, seqlen_q, nheads_q, head_size = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
|
||||||
|
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
|
||||||
|
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
|
||||||
|
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
|
||||||
|
|
||||||
|
# Get closest power of 2 over or equal to 32.
|
||||||
|
unpadded_head_dims = {32, 64, 128}
|
||||||
|
if head_size not in unpadded_head_dims:
|
||||||
|
padded_d_model = None
|
||||||
|
for i in unpadded_head_dims:
|
||||||
|
if i > head_size:
|
||||||
|
padded_d_model = i
|
||||||
|
break
|
||||||
|
assert padded_d_model is not None
|
||||||
|
else:
|
||||||
|
padded_d_model = head_size
|
||||||
|
|
||||||
|
grid = lambda META: (
|
||||||
|
triton.cdiv(max_seqlens_q, META["BLOCK_M"]),
|
||||||
|
nheads_q,
|
||||||
|
batch,
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded_softmax = None
|
||||||
|
|
||||||
|
# Seed the RNG so we get reproducible results for testing.
|
||||||
|
philox_seed = 0x1BF52
|
||||||
|
philox_offset = 0x1D4B42
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
bias_strides = (
|
||||||
|
bias.stride(0),
|
||||||
|
bias.stride(1),
|
||||||
|
bias.stride(2),
|
||||||
|
bias.stride(3),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
bias_strides = (0, 0, 0, 0)
|
||||||
|
|
||||||
|
attn_fwd[grid](
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
bias,
|
||||||
|
sm_scale,
|
||||||
|
None,
|
||||||
|
o,
|
||||||
|
*q_strides,
|
||||||
|
*k_strides,
|
||||||
|
*v_strides,
|
||||||
|
*o_strides,
|
||||||
|
*bias_strides,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
dropout_p=0.0,
|
||||||
|
philox_seed=philox_seed,
|
||||||
|
philox_offset_base=philox_offset,
|
||||||
|
encoded_softmax=encoded_softmax,
|
||||||
|
hq=nheads_q,
|
||||||
|
hk=nheads_k,
|
||||||
|
ACTUAL_BLOCK_DMODEL=head_size,
|
||||||
|
MAX_SEQLENS_Q=max_seqlens_q,
|
||||||
|
MAX_SEQLENS_K=max_seqlens_k,
|
||||||
|
IS_CAUSAL=causal,
|
||||||
|
VARLEN=True,
|
||||||
|
BLOCK_DMODEL=padded_d_model,
|
||||||
|
BIAS_TYPE=0 if bias is None else 1,
|
||||||
|
ENABLE_DROPOUT=False,
|
||||||
|
RETURN_ENCODED_SOFTMAX=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx.grid = grid
|
||||||
|
ctx.sm_scale = sm_scale
|
||||||
|
ctx.BLOCK_DMODEL = head_size
|
||||||
|
ctx.causal = causal
|
||||||
|
ctx.dropout_p = 0.0
|
||||||
|
ctx.philox_seed = philox_seed
|
||||||
|
ctx.philox_offset = philox_offset
|
||||||
|
ctx.encoded_softmax = encoded_softmax
|
||||||
|
ctx.return_encoded_softmax = False
|
||||||
|
return o, encoded_softmax
|
||||||
|
|
||||||
|
|
||||||
|
triton_attention = _attention.apply
|
@ -1,3 +1,4 @@
|
|||||||
|
import enum
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
@ -10,46 +11,68 @@ from vllm.utils import is_cpu, is_hip
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class _Backend(enum.Enum):
|
||||||
|
FLASH_ATTN = enum.auto()
|
||||||
|
XFORMERS = enum.auto()
|
||||||
|
ROCM_FLASH = enum.auto()
|
||||||
|
TORCH_SDPA = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=None)
|
@lru_cache(maxsize=None)
|
||||||
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
|
||||||
if _can_use_flash_attn(dtype):
|
backend = _which_attn_to_use(dtype)
|
||||||
|
if backend == _Backend.FLASH_ATTN:
|
||||||
logger.info("Using FlashAttention backend.")
|
logger.info("Using FlashAttention backend.")
|
||||||
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
from vllm.attention.backends.flash_attn import ( # noqa: F401
|
||||||
FlashAttentionBackend)
|
FlashAttentionBackend)
|
||||||
return FlashAttentionBackend
|
return FlashAttentionBackend
|
||||||
elif is_cpu():
|
elif backend == _Backend.XFORMERS:
|
||||||
logger.info("Using Torch SDPA backend.")
|
|
||||||
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
|
||||||
return TorchSDPABackend
|
|
||||||
else:
|
|
||||||
logger.info("Using XFormers backend.")
|
logger.info("Using XFormers backend.")
|
||||||
from vllm.attention.backends.xformers import ( # noqa: F401
|
from vllm.attention.backends.xformers import ( # noqa: F401
|
||||||
XFormersBackend)
|
XFormersBackend)
|
||||||
return XFormersBackend
|
return XFormersBackend
|
||||||
|
elif backend == _Backend.ROCM_FLASH:
|
||||||
|
logger.info("Using ROCmFlashAttention backend.")
|
||||||
|
from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401
|
||||||
|
ROCmFlashAttentionBackend)
|
||||||
|
return ROCmFlashAttentionBackend
|
||||||
|
elif backend == _Backend.TORCH_SDPA:
|
||||||
|
logger.info("Using Torch SDPA backend.")
|
||||||
|
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
|
||||||
|
return TorchSDPABackend
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid attention backend.")
|
||||||
|
|
||||||
|
|
||||||
def _can_use_flash_attn(dtype: torch.dtype) -> bool:
|
def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
|
||||||
|
"""Returns which flash attention backend to use."""
|
||||||
|
if is_cpu():
|
||||||
|
return _Backend.TORCH_SDPA
|
||||||
|
|
||||||
if is_hip():
|
if is_hip():
|
||||||
# AMD GPUs.
|
# AMD GPUs.
|
||||||
logger.info("Cannot use FlashAttention backend for AMD GPUs.")
|
if torch.cuda.get_device_capability()[0] != 9:
|
||||||
return False
|
# not Instinct series GPUs.
|
||||||
if is_cpu():
|
logger.info("flash_atten is not supported on NAVI GPUs.")
|
||||||
return False
|
return _Backend.ROCM_FLASH
|
||||||
|
|
||||||
|
# NVIDIA GPUs.
|
||||||
if torch.cuda.get_device_capability()[0] < 8:
|
if torch.cuda.get_device_capability()[0] < 8:
|
||||||
# Volta and Turing NVIDIA GPUs.
|
# Volta and Turing NVIDIA GPUs.
|
||||||
logger.info("Cannot use FlashAttention backend for Volta and Turing "
|
logger.info("Cannot use FlashAttention backend for Volta and Turing "
|
||||||
"GPUs.")
|
"GPUs.")
|
||||||
return False
|
return _Backend.XFORMERS
|
||||||
|
|
||||||
if dtype not in (torch.float16, torch.bfloat16):
|
if dtype not in (torch.float16, torch.bfloat16):
|
||||||
logger.info("Cannot use FlashAttention backend for dtype other than "
|
logger.info("Cannot use FlashAttention backend for dtype other than "
|
||||||
"torch.float16 or torch.bfloat16.")
|
"torch.float16 or torch.bfloat16.")
|
||||||
return False
|
return _Backend.XFORMERS
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import flash_attn # noqa: F401
|
import flash_attn # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Cannot use FlashAttention because the package is not found. "
|
"Cannot use FlashAttention backend because the flash_attn package "
|
||||||
"Please install it for better performance.")
|
"is not found. Please install it for better performance.")
|
||||||
return False
|
return _Backend.XFORMERS
|
||||||
return True
|
return _Backend.FLASH_ATTN
|
||||||
|
Loading…
x
Reference in New Issue
Block a user