[Misc] Support attention logits soft-capping with flash-attn (#7022)
parent 562e580abc
commit 805a8a75f2
@@ -8,4 +8,4 @@ torch == 2.4.0
 # These must be updated alongside torch
 torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.27.post2 # Requires PyTorch 2.4.0
-vllm-flash-attn == 2.6.0 # Requires PyTorch 2.4.0
+vllm-flash-attn == 2.6.1 # Requires PyTorch 2.4.0
@@ -20,6 +20,7 @@ def ref_paged_attn(
     block_tables: torch.Tensor,
     scale: float,
     sliding_window: Optional[int] = None,
+    soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     num_seqs = len(query_lens)
     block_tables = block_tables.cpu().numpy()
@@ -53,6 +54,8 @@ def ref_paged_attn(
                                              (query_len + sliding_window) +
                                              1).bool().logical_not()
             mask |= sliding_window_mask
+        if soft_cap is not None:
+            attn = soft_cap * torch.tanh(attn / soft_cap)
         attn.masked_fill_(mask, float("-inf"))
         attn = torch.softmax(attn, dim=-1).to(v.dtype)
         out = torch.einsum("hqk,khd->qhd", attn, v)
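Note on the reference math above: soft-capping is a smooth clamp, so raw scores stay strictly within (-soft_cap, soft_cap) while small scores pass through almost unchanged. A minimal, self-contained sketch (plain PyTorch, not code from this diff):

```python
import torch

def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Smoothly bound attention logits to (-soft_cap, soft_cap).
    return soft_cap * torch.tanh(logits / soft_cap)

logits = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap_logits(logits, soft_cap=50.0))
# Roughly [-48.2, -4.98, 0.0, 4.98, 48.2]: large magnitudes saturate near +/-50,
# small ones are nearly unchanged.
```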
@@ -68,13 +71,15 @@ def ref_paged_attn(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
     num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
+    soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
         causal=True,
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
+        softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)

     ref_output = ref_paged_attn(
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
         kv_lens=kv_lens,
         block_tables=block_tables,
         scale=scale,
+        soft_cap=soft_cap,
     )
     assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("sliding_window", [None])
 @pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
 def test_varlen_with_paged_kv(
     seq_lens: List[Tuple[int, int]],
     num_heads: Tuple[int, int],
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
     sliding_window: Optional[int],
     dtype: torch.dtype,
     block_size: int,
+    soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
                             head_size,
                             dtype=dtype)
     value_cache = torch.randn_like(key_cache)
-    # Normalize the scale of the key and value caches to mitigate
-    # numerical instability.
-    key_cache /= head_size**0.5
-    value_cache /= head_size**0.5
     cu_query_lens = torch.tensor([0] + query_lens,
                                  dtype=torch.int32).cumsum(dim=0,
                                                            dtype=torch.int32)
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
         causal=True,
         window_size=window_size,
         block_table=block_tables,
+        softcap=soft_cap if soft_cap is not None else 0,
     )

     ref_output = ref_paged_attn(
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
         block_tables=block_tables,
         scale=scale,
         sliding_window=sliding_window,
+        soft_cap=soft_cap,
     )
     assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
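The same capping also reads cleanly in a plain, non-paged attention reference. The sketch below is illustrative only (the shapes, names, and single-sequence simplification are not from the test file); it mirrors the order of operations in ref_paged_attn: cap first, then causal mask, then softmax:

```python
import torch

def ref_attn_soft_capped(q, k, v, scale, soft_cap=None):
    # q, k, v: [seq_len, num_heads, head_dim]; single sequence, no paging.
    attn = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * scale
    if soft_cap is not None:
        # Apply the soft cap to raw scores before masking and softmax.
        attn = soft_cap * torch.tanh(attn / soft_cap)
    q_len, kv_len = q.shape[0], k.shape[0]
    mask = torch.triu(torch.ones(q_len, kv_len),
                      diagonal=kv_len - q_len + 1).bool()
    attn.masked_fill_(mask, float("-inf"))
    attn = torch.softmax(attn, dim=-1).to(v.dtype)
    return torch.einsum("hqk,khd->qhd", attn, v)

# Example: 4 tokens, 2 heads, head dim 8.
q = torch.randn(4, 2, 8)
k = torch.randn(4, 2, 8)
v = torch.randn(4, 2, 8)
out = ref_attn_soft_capped(q, k, v, scale=8 ** -0.5, soft_cap=30.0)
print(out.shape)  # torch.Size([4, 2, 8])
```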
@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
         sliding_window: Optional[int] = None,
         kv_cache_dtype: str = "auto",
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         raise NotImplementedError

@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         assert blocksparse_params is not None
         assert alibi_slopes is None, ValueError(
             "Alibi not support for blocksparse flash attention.")
         assert sliding_window is None, ValueError(
             "sliding_window is invalid for blocksparse attention.")
+        assert logits_soft_cap is None, ValueError(
+            "logits_soft_cap is invalid for blocksparse attention.")

         if "num_heads" not in blocksparse_params:
             blocksparse_params["num_heads"] = num_heads
@@ -288,15 +288,6 @@ class FlashAttentionMetadataBuilder(
         device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "Please use Flashinfer backend for models with logits_soft_cap"
-                " (i.e., Gemma-2). Otherwise, the output might be wrong."
-                " Set Flashinfer backend by "
-                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
-
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
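Removing this check means the flash-attn backend now handles soft-capped models directly, so the environment-variable workaround the old error message demanded becomes optional rather than required. For reference, that workaround looked roughly like this (sketch; set before vLLM selects its attention backend):

```python
import os

# Previously required for Gemma-2-style models; now optional, since the
# flash-attn backend applies the soft cap itself after this change.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
```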
@@ -405,8 +396,10 @@ class FlashAttentionImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "FlashAttention does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "FlashAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
@@ -418,6 +411,10 @@ class FlashAttentionImpl(AttentionImpl):
         self.sliding_window = ((sliding_window, sliding_window)
                                if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
+            logits_soft_cap = 0
+        self.logits_soft_cap = logits_soft_cap

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
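The None-to-0 normalization above follows flash-attn's convention that a softcap of 0 disables capping. A tiny illustrative helper (hypothetical, not part of the diff) makes the convention explicit and keeps the forward path free of None checks:

```python
from typing import Optional

def normalize_soft_cap(logits_soft_cap: Optional[float]) -> float:
    # flash-attn's `softcap` argument treats 0 as "disabled", so map
    # None -> 0 once at construction time instead of branching per call.
    return 0.0 if logits_soft_cap is None else logits_soft_cap

assert normalize_soft_cap(None) == 0.0
assert normalize_soft_cap(50.0) == 50.0
```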
@@ -525,6 +522,7 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     window_size=self.sliding_window,
                     alibi_slopes=self.alibi_slopes,
+                    softcap=self.logits_soft_cap,
                 )
                 assert output[:num_prefill_tokens].shape == out.shape
                 output[:num_prefill_tokens] = out
@@ -544,6 +542,7 @@ class FlashAttentionImpl(AttentionImpl):
                     causal=True,
                     alibi_slopes=self.alibi_slopes,
                     block_table=prefill_meta.block_tables,
+                    softcap=self.logits_soft_cap,
                 )

         if decode_meta := attn_metadata.decode_metadata:
@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
     # The data type of the paged kv cache
     data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
-    # Only used by gemma2 model
-    logits_soft_cap: Optional[float] = None

     def __post_init__(self):
         # Refer to
@@ -391,9 +389,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                             dtype=torch.long,
                                             device=device)

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-
         if len(self.paged_kv_indptr) > 0:
             paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                    device="cpu",
@@ -430,8 +425,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
-            use_cuda_graph=use_captured_graph,
-            logits_soft_cap=logits_soft_cap)
+            use_cuda_graph=use_captured_graph)


 class FlashInferImpl(AttentionImpl):
@@ -446,6 +440,7 @@ class FlashInferImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -458,6 +453,7 @@ class FlashInferImpl(AttentionImpl):
             raise ValueError("Sliding window is not supported in FlashInfer.")
         self.sliding_window = (-1, -1)
         self.kv_cache_dtype = kv_cache_dtype
+        self.logits_soft_cap = logits_soft_cap

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -532,7 +528,7 @@ class FlashInferImpl(AttentionImpl):
             output = prefill_meta.prefill_wrapper.forward(
                 query,
                 kv_cache,
-                logits_soft_cap=attn_metadata.logits_soft_cap,
+                logits_soft_cap=self.logits_soft_cap,
                 causal=True)
         else:
             assert attn_metadata.decode_metadata is not None
@@ -541,5 +537,5 @@ class FlashInferImpl(AttentionImpl):
                 query,
                 kv_cache,
                 sm_scale=self.scale,
-                logits_soft_cap=attn_metadata.logits_soft_cap)
+                logits_soft_cap=self.logits_soft_cap)
         return output.view(num_tokens, hidden_size)
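With these FlashInfer changes, the soft cap is wired in once at impl-construction time instead of being re-read from the HF config for every batch's metadata. A rough sketch of the resulting flow (MyAttentionImpl and the SimpleNamespace config are stand-ins, not code from the diff):

```python
from types import SimpleNamespace
from typing import Optional

# Stand-in for a Hugging Face config; Gemma-2 exposes `attn_logit_softcapping`.
hf_config = SimpleNamespace(attn_logit_softcapping=50.0)

class MyAttentionImpl:  # hypothetical minimal impl for illustration
    def __init__(self, logits_soft_cap: Optional[float] = None) -> None:
        # Constructor-time wiring replaces the per-batch metadata lookup.
        self.logits_soft_cap = logits_soft_cap

impl = MyAttentionImpl(
    logits_soft_cap=getattr(hf_config, "attn_logit_softcapping", None))
print(impl.logits_soft_cap)  # 50.0
```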
@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "Torch SPDA does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "IPEX backend does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("IPEX backend does not support logits_soft_cap.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -91,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -109,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
+        if logits_soft_cap is not None:
+            raise NotImplementedError(
+                "Attention logits soft-capping is not supported.")

         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "ROCFlashAttention does not support blocksparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support blocksparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support attention logits soft "
+                "capping.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "Torch SPDA does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "Torch SPDA does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("Torch SPDA does not support logits soft cap.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -165,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "Please use Flashinfer backend for models with logits_soft_cap "
-                "(i.e., Gemma-2). Otherwise, the output might be wrong. "
-                "Set Flashinfer backend by "
-                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
-
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "XFormer does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "XFormers does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "XFormers does not support attention logits soft capping.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
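The IPEX, Pallas, ROCm, Torch SDPA, and XFormers hunks above all follow the same guard pattern: backends without kernel support reject logits_soft_cap up front rather than silently returning uncapped output. A condensed sketch of that pattern (the helper name is hypothetical):

```python
from typing import Optional

def check_soft_cap_supported(logits_soft_cap: Optional[float],
                             backend_name: str) -> None:
    # Fail fast if a model asks for soft-capping on a backend whose kernels
    # cannot apply it; a silent fallback would change the model's outputs.
    if logits_soft_cap is not None:
        raise ValueError(
            f"{backend_name} does not support attention logits soft capping.")

check_soft_cap_supported(None, "XFormers")    # no-op
# check_soft_cap_supported(50.0, "XFormers")  # would raise ValueError
```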
@@ -34,6 +34,7 @@ class Attention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -82,7 +83,7 @@ class Attention(nn.Module):
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
-                             blocksparse_params)
+                             blocksparse_params, logits_soft_cap)

     def forward(
         self,
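Model code opts in simply by passing the new keyword through the Attention wrapper. A hedged usage sketch (the import path, head counts, and scale are placeholders inferred from the repo layout; only the logits_soft_cap keyword is what this commit adds, and it assumes a configured vLLM environment to actually construct a backend):

```python
from vllm.attention import Attention  # import path assumed

attn = Attention(
    num_heads=16,
    head_size=256,
    scale=256 ** -0.5,
    num_kv_heads=8,
    logits_soft_cap=50.0,  # new in this commit; None (the default) disables capping
)
```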
@@ -90,7 +90,8 @@ class Gemma2Attention(nn.Module):
                  max_position_embeddings: int,
                  rope_theta: float,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 attn_logits_soft_cap: Optional[float] = None) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
@@ -150,7 +151,8 @@ class Gemma2Attention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
-                              quant_config=quant_config)
+                              quant_config=quant_config,
+                              logits_soft_cap=attn_logits_soft_cap)

     def forward(
         self,
@@ -189,6 +191,7 @@ class Gemma2DecoderLayer(nn.Module):
             rope_theta=config.rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
+            attn_logits_soft_cap=config.attn_logit_softcapping,
         )
         self.hidden_size = config.hidden_size
         self.mlp = Gemma2MLP(
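The cap value itself comes from the Hugging Face config field attn_logit_softcapping, as the Gemma2DecoderLayer hunk shows. A quick way to inspect it (the checkpoint name and the 50.0 value reflect the released Gemma-2 configs, so treat them as indicative):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/gemma-2-9b")
print(cfg.attn_logit_softcapping)                      # typically 50.0 for Gemma-2
print(getattr(cfg, "final_logit_softcapping", None))   # separate cap for the LM head, if present
```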