[Misc] Support attention logits soft-capping with flash-attn (#7022)
parent 562e580abc
commit 805a8a75f2
@@ -8,4 +8,4 @@ torch == 2.4.0
 # These must be updated alongside torch
 torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.27.post2 # Requires PyTorch 2.4.0
-vllm-flash-attn == 2.6.0 # Requires PyTorch 2.4.0
+vllm-flash-attn == 2.6.1 # Requires PyTorch 2.4.0
@@ -20,6 +20,7 @@ def ref_paged_attn(
     block_tables: torch.Tensor,
     scale: float,
     sliding_window: Optional[int] = None,
+    soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     num_seqs = len(query_lens)
     block_tables = block_tables.cpu().numpy()
@@ -53,6 +54,8 @@ def ref_paged_attn(
                                              (query_len + sliding_window) +
                                              1).bool().logical_not()
             mask |= sliding_window_mask
+        if soft_cap is not None:
+            attn = soft_cap * torch.tanh(attn / soft_cap)
         attn.masked_fill_(mask, float("-inf"))
         attn = torch.softmax(attn, dim=-1).to(v.dtype)
         out = torch.einsum("hqk,khd->qhd", attn, v)
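For reference, the capping applied above is cap * tanh(logits / cap): it bounds the raw attention scores to (-cap, cap) while leaving small scores nearly unchanged. A minimal standalone sketch, not part of this diff:

    import torch

    def soft_cap_logits(logits: torch.Tensor, cap: float) -> torch.Tensor:
        # Squash raw attention scores into (-cap, cap); near-identity when
        # |logits| is small relative to cap.
        return cap * torch.tanh(logits / cap)

    scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
    print(soft_cap_logits(scores, cap=50.0))
    # tensor([-48.2014,  -4.9834,   0.0000,   4.9834,  48.2014])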
@@ -68,13 +71,15 @@ def ref_paged_attn(
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
     num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
+    soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
         causal=True,
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
+        softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)

     ref_output = ref_paged_attn(
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
         kv_lens=kv_lens,
         block_tables=block_tables,
         scale=scale,
+        soft_cap=soft_cap,
     )
     assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
@@ -129,7 +136,8 @@ def test_flash_attn_with_paged_kv(
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("sliding_window", [None])
 @pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
 def test_varlen_with_paged_kv(
     seq_lens: List[Tuple[int, int]],
     num_heads: Tuple[int, int],
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
     sliding_window: Optional[int],
     dtype: torch.dtype,
     block_size: int,
+    soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
                             head_size,
                             dtype=dtype)
     value_cache = torch.randn_like(key_cache)
-    # Normalize the scale of the key and value caches to mitigate
-    # numerical instability.
-    key_cache /= head_size**0.5
-    value_cache /= head_size**0.5
     cu_query_lens = torch.tensor([0] + query_lens,
                                  dtype=torch.int32).cumsum(dim=0,
                                                            dtype=torch.int32)
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
         causal=True,
         window_size=window_size,
         block_table=block_tables,
+        softcap=soft_cap if soft_cap is not None else 0,
     )

     ref_output = ref_paged_attn(
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
         block_tables=block_tables,
         scale=scale,
         sliding_window=sliding_window,
+        soft_cap=soft_cap,
     )
     assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
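The tests parametrize soft_cap over [None, 10.0, 50.0] and compare the kernel output against the capped reference. A sanity check worth keeping in mind when debugging tolerances: as the cap grows large relative to the logit magnitudes, capping approaches the identity, so capped and uncapped outputs converge. Illustrative sketch, not part of this diff:

    import torch

    torch.manual_seed(0)
    logits = torch.randn(8, 8) * 3.0           # typical pre-softmax scores
    uncapped = torch.softmax(logits, dim=-1)

    for cap in (10.0, 50.0, 1000.0):
        capped = torch.softmax(cap * torch.tanh(logits / cap), dim=-1)
        print(cap, (capped - uncapped).abs().max().item())
    # the gap shrinks as cap grows, since cap * tanh(x / cap) -> x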
@@ -150,6 +150,7 @@ class AttentionImpl(ABC, Generic[T]):
         sliding_window: Optional[int] = None,
         kv_cache_dtype: str = "auto",
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         raise NotImplementedError

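With logits_soft_cap added to the abstract constructor, every backend implementation now receives the value, and backends that cannot honor it are expected to reject it explicitly (as the backend hunks below do). A minimal sketch of that contract, using a hypothetical backend class rather than any real vLLM backend:

    from typing import Any, Dict, List, Optional

    class NoSoftCapAttentionImpl:  # hypothetical backend, illustration only
        def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str = "auto",
            blocksparse_params: Optional[Dict[str, Any]] = None,
            logits_soft_cap: Optional[float] = None,
        ) -> None:
            if logits_soft_cap is not None:
                # Fail loudly rather than silently ignoring the cap.
                raise ValueError(
                    "This backend does not support logits soft cap.")
            self.num_heads = num_heads
            self.head_size = head_size
            self.scale = float(scale)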
@@ -283,12 +283,15 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         assert blocksparse_params is not None
         assert alibi_slopes is None, ValueError(
             "Alibi not support for blocksparse flash attention.")
         assert sliding_window is None, ValueError(
             "sliding_window is invalid for blocksparse attention.")
+        assert logits_soft_cap is None, ValueError(
+            "logits_soft_cap is invalid for blocksparse attention.")

         if "num_heads" not in blocksparse_params:
             blocksparse_params["num_heads"] = num_heads
@@ -288,15 +288,6 @@ class FlashAttentionMetadataBuilder(
         device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "Please use Flashinfer backend for models with logits_soft_cap"
-                " (i.e., Gemma-2). Otherwise, the output might be wrong."
-                " Set Flashinfer backend by "
-                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
-
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
@@ -405,9 +396,11 @@ class FlashAttentionImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "FlashAttention does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "FlashAttention does not support block-sparse attention.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -418,6 +411,10 @@ class FlashAttentionImpl(AttentionImpl):
         self.sliding_window = ((sliding_window, sliding_window)
                                if sliding_window is not None else (-1, -1))
         self.kv_cache_dtype = kv_cache_dtype
+        if logits_soft_cap is None:
+            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
+            logits_soft_cap = 0
+        self.logits_soft_cap = logits_soft_cap

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
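Note the convention bridged here: vLLM models pass None for "no capping", while the flash-attn kernels use softcap=0 for the same thing, so the constructor normalizes None to 0 once and the hot path always forwards softcap=self.logits_soft_cap. A reference helper mirroring that convention, illustrative only:

    import torch

    def ref_softcap(logits: torch.Tensor, softcap: float) -> torch.Tensor:
        # flash-attn convention: softcap == 0 means capping is disabled.
        if softcap == 0:
            return logits
        return softcap * torch.tanh(logits / softcap)

    x = torch.randn(4)
    assert torch.equal(ref_softcap(x, 0.0), x)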
@@ -525,6 +522,7 @@ class FlashAttentionImpl(AttentionImpl):
                 causal=True,
                 window_size=self.sliding_window,
                 alibi_slopes=self.alibi_slopes,
+                softcap=self.logits_soft_cap,
             )
             assert output[:num_prefill_tokens].shape == out.shape
             output[:num_prefill_tokens] = out
@@ -544,6 +542,7 @@ class FlashAttentionImpl(AttentionImpl):
                 causal=True,
                 alibi_slopes=self.alibi_slopes,
                 block_table=prefill_meta.block_tables,
+                softcap=self.logits_soft_cap,
             )

         if decode_meta := attn_metadata.decode_metadata:
@@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata):
     # The data type of the paged kv cache
     data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
-    # Only used by gemma2 model
-    logits_soft_cap: Optional[float] = None

     def __post_init__(self):
         # Refer to
@@ -391,9 +389,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
                                            dtype=torch.long,
                                            device=device)

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-
         if len(self.paged_kv_indptr) > 0:
             paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                    device="cpu",
@@ -430,8 +425,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
-            use_cuda_graph=use_captured_graph,
-            logits_soft_cap=logits_soft_cap)
+            use_cuda_graph=use_captured_graph)


 class FlashInferImpl(AttentionImpl):
@@ -446,6 +440,7 @@ class FlashInferImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -458,6 +453,7 @@ class FlashInferImpl(AttentionImpl):
             raise ValueError("Sliding window is not supported in FlashInfer.")
         self.sliding_window = (-1, -1)
         self.kv_cache_dtype = kv_cache_dtype
+        self.logits_soft_cap = logits_soft_cap

         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -532,7 +528,7 @@ class FlashInferImpl(AttentionImpl):
             output = prefill_meta.prefill_wrapper.forward(
                 query,
                 kv_cache,
-                logits_soft_cap=attn_metadata.logits_soft_cap,
+                logits_soft_cap=self.logits_soft_cap,
                 causal=True)
         else:
             assert attn_metadata.decode_metadata is not None
@@ -541,5 +537,5 @@ class FlashInferImpl(AttentionImpl):
                 query,
                 kv_cache,
                 sm_scale=self.scale,
-                logits_soft_cap=attn_metadata.logits_soft_cap)
+                logits_soft_cap=self.logits_soft_cap)
         return output.view(num_tokens, hidden_size)
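Design note: the cap used to ride on FlashInferMetadata, which is rebuilt every scheduling step from the model config; it is now stored on the impl, which is constructed once per layer, matching the flash-attn backend. A toy sketch of the pattern (made-up names, not the vLLM classes):

    from typing import Optional

    class PerLayerImpl:  # illustrative stand-in only
        """The cap is a per-layer constant, so store it at construction."""

        def __init__(self, scale: float,
                     logits_soft_cap: Optional[float] = None):
            self.scale = scale
            self.logits_soft_cap = logits_soft_cap

        def forward_kwargs(self) -> dict:
            # Per-step metadata no longer needs to carry or rebuild the cap.
            return {"sm_scale": self.scale,
                    "logits_soft_cap": self.logits_soft_cap}

    print(PerLayerImpl(scale=0.125, logits_soft_cap=50.0).forward_kwargs())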
@@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "Torch SPDA does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "IPEX backend does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("IPEX backend does not support logits_soft_cap.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -91,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -109,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
+        if logits_soft_cap is not None:
+            raise NotImplementedError(
+                "Attention logits soft-capping is not supported.")

         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
@@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "ROCFlashAttention does not support blocksparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support blocksparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "ROCmFlashAttention does not support attention logits soft "
+                "capping.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "Torch SPDA does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "Torch SPDA does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError("Torch SPDA does not support logits soft cap.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -165,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
         device = self.runner.device
         use_captured_graph = cuda_graph_pad_size != -1

-        logits_soft_cap = getattr(self.runner.model_config.hf_config,
-                                  "attn_logit_softcapping", None)
-        if logits_soft_cap is not None:
-            raise ValueError(
-                "Please use Flashinfer backend for models with logits_soft_cap "
-                "(i.e., Gemma-2). Otherwise, the output might be wrong. "
-                "Set Flashinfer backend by "
-                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")
-
         max_query_len = max(query_lens)
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
@@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
         sliding_window: Optional[int],
         kv_cache_dtype: str,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
     ) -> None:
-        assert blocksparse_params is None, ValueError(
-            "XFormer does not support block-sparse attention.")
+        if blocksparse_params is not None:
+            raise ValueError(
+                "XFormers does not support block-sparse attention.")
+        if logits_soft_cap is not None:
+            raise ValueError(
+                "XFormers does not support attention logits soft capping.")
         self.num_heads = num_heads
         self.head_size = head_size
         self.scale = float(scale)
@@ -34,6 +34,7 @@ class Attention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
+        logits_soft_cap: Optional[float] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
@@ -82,7 +83,7 @@ class Attention(nn.Module):
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
-                             blocksparse_params)
+                             blocksparse_params, logits_soft_cap)

     def forward(
         self,
@@ -90,7 +90,8 @@ class Gemma2Attention(nn.Module):
                  max_position_embeddings: int,
                  rope_theta: float,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 attn_logits_soft_cap: Optional[float] = None) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
@@ -150,7 +151,8 @@ class Gemma2Attention(nn.Module):
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
-                              quant_config=quant_config)
+                              quant_config=quant_config,
+                              logits_soft_cap=attn_logits_soft_cap)

     def forward(
         self,
@@ -189,6 +191,7 @@ class Gemma2DecoderLayer(nn.Module):
             rope_theta=config.rope_theta,
             cache_config=cache_config,
             quant_config=quant_config,
+            attn_logits_soft_cap=config.attn_logit_softcapping,
         )
         self.hidden_size = config.hidden_size
         self.mlp = Gemma2MLP(
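End to end, the Gemma-2 value now flows config.attn_logit_softcapping -> Gemma2Attention(attn_logits_soft_cap=...) -> Attention(logits_soft_cap=...) -> the backend impl. A condensed sketch of that plumbing with toy stand-in classes, illustrative only:

    from typing import Optional

    class ToyAttention:  # stand-in for vllm.attention.Attention
        def __init__(self, scale: float,
                     logits_soft_cap: Optional[float] = None):
            self.logits_soft_cap = logits_soft_cap

    class ToyGemma2Attention:  # stand-in for Gemma2Attention
        def __init__(self, scaling: float,
                     attn_logits_soft_cap: Optional[float] = None):
            self.attn = ToyAttention(scaling,
                                     logits_soft_cap=attn_logits_soft_cap)

    # Gemma-2 configs expose attn_logit_softcapping (e.g. 50.0).
    layer = ToyGemma2Attention(scaling=0.0625, attn_logits_soft_cap=50.0)
    print(layer.attn.logits_soft_cap)  # 50.0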