diff --git a/requirements-cuda.txt b/requirements-cuda.txt
index 1f60d540..1d00f0c1 100644
--- a/requirements-cuda.txt
+++ b/requirements-cuda.txt
@@ -8,4 +8,4 @@ torch == 2.4.0
 # These must be updated alongside torch
 torchvision == 0.19  # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
 xformers == 0.0.27.post2  # Requires PyTorch 2.4.0
-vllm-flash-attn == 2.6.0  # Requires PyTorch 2.4.0
+vllm-flash-attn == 2.6.1  # Requires PyTorch 2.4.0
diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py
index cd06c271..6c5eff00 100644
--- a/tests/kernels/test_flash_attn.py
+++ b/tests/kernels/test_flash_attn.py
@@ -20,6 +20,7 @@ def ref_paged_attn(
     block_tables: torch.Tensor,
     scale: float,
     sliding_window: Optional[int] = None,
+    soft_cap: Optional[float] = None,
 ) -> torch.Tensor:
     num_seqs = len(query_lens)
     block_tables = block_tables.cpu().numpy()
@@ -53,6 +54,8 @@ def ref_paged_attn(
                                              (query_len + sliding_window) +
                                              1).bool().logical_not()
             mask |= sliding_window_mask
+        if soft_cap is not None:
+            attn = soft_cap * torch.tanh(attn / soft_cap)
         attn.masked_fill_(mask, float("-inf"))
         attn = torch.softmax(attn, dim=-1).to(v.dtype)
         out = torch.einsum("hqk,khd->qhd", attn, v)
@@ -68,13 +71,15 @@
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
 def test_flash_attn_with_paged_kv(
     kv_lens: List[int],
     num_heads: Tuple[int, int],
     head_size: int,
     dtype: torch.dtype,
     block_size: int,
+    soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -108,6 +113,7 @@ def test_flash_attn_with_paged_kv(
         causal=True,
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
+        softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)
 
     ref_output = ref_paged_attn(
@@ -118,6 +124,7 @@ def test_flash_attn_with_paged_kv(
         kv_lens=kv_lens,
         block_tables=block_tables,
         scale=scale,
+        soft_cap=soft_cap,
     )
     assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
@@ -129,7 +136,8 @@
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("sliding_window", [None])
 @pytest.mark.parametrize("dtype", DTYPES)
-@torch.inference_mode
+@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
+@torch.inference_mode()
 def test_varlen_with_paged_kv(
     seq_lens: List[Tuple[int, int]],
     num_heads: Tuple[int, int],
@@ -137,6 +145,7 @@ def test_varlen_with_paged_kv(
     sliding_window: Optional[int],
     dtype: torch.dtype,
     block_size: int,
+    soft_cap: Optional[float],
 ) -> None:
     torch.set_default_device("cuda")
     torch.cuda.manual_seed_all(0)
@@ -163,10 +172,6 @@ def test_varlen_with_paged_kv(
                             head_size,
                             dtype=dtype)
     value_cache = torch.randn_like(key_cache)
-    # Normalize the scale of the key and value caches to mitigate
-    # numerical instability.
-    key_cache /= head_size**0.5
-    value_cache /= head_size**0.5
     cu_query_lens = torch.tensor([0] + query_lens,
                                  dtype=torch.int32).cumsum(dim=0,
                                                            dtype=torch.int32)
@@ -192,6 +197,7 @@ def test_varlen_with_paged_kv(
         causal=True,
         window_size=window_size,
         block_table=block_tables,
+        softcap=soft_cap if soft_cap is not None else 0,
     )
 
     ref_output = ref_paged_attn(
@@ -203,6 +209,7 @@ def test_varlen_with_paged_kv(
         block_tables=block_tables,
         scale=scale,
         sliding_window=sliding_window,
+        soft_cap=soft_cap,
     )
     assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"
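
[Note, not part of the patch] The reference path added to ref_paged_attn above caps the raw attention scores to (-soft_cap, soft_cap) with a tanh before masking and softmax; this is the soft-capping scheme Gemma-2 applies to its attention logits. A minimal standalone sketch of just that transformation (plain PyTorch, illustrative names only):

    import torch

    def cap_logits(scores: torch.Tensor, soft_cap: float) -> torch.Tensor:
        # soft_cap * tanh(x / soft_cap) is approximately x for |x| << soft_cap
        # and saturates at +/- soft_cap for large |x|.
        return soft_cap * torch.tanh(scores / soft_cap)

    scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
    print(cap_logits(scores, soft_cap=50.0))  # values stay inside (-50, 50)
    print(cap_logits(scores, soft_cap=10.0))  # tighter cap, stronger squashing

The tests exercise soft_cap in {None, 10.0, 50.0} and pass softcap=0 to the kernels when capping is disabled, matching the vllm-flash-attn convention that 0 means no capping.
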
- " Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) @@ -405,9 +396,11 @@ class FlashAttentionImpl(AttentionImpl): sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "FlashAttention does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) @@ -418,6 +411,10 @@ class FlashAttentionImpl(AttentionImpl): self.sliding_window = ((sliding_window, sliding_window) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -525,6 +522,7 @@ class FlashAttentionImpl(AttentionImpl): causal=True, window_size=self.sliding_window, alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, ) assert output[:num_prefill_tokens].shape == out.shape output[:num_prefill_tokens] = out @@ -544,6 +542,7 @@ class FlashAttentionImpl(AttentionImpl): causal=True, alibi_slopes=self.alibi_slopes, block_table=prefill_meta.block_tables, + softcap=self.logits_soft_cap, ) if decode_meta := attn_metadata.decode_metadata: diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 83a420d7..ccf8ab03 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -116,8 +116,6 @@ class FlashInferMetadata(AttentionMetadata): # The data type of the paged kv cache data_type: torch.dtype = None device: torch.device = torch.device("cuda") - # Only used by gemma2 model - logits_soft_cap: Optional[float] = None def __post_init__(self): # Refer to @@ -391,9 +389,6 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): dtype=torch.long, device=device) - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if len(self.paged_kv_indptr) > 0: paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", @@ -430,8 +425,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): query_start_loc=query_start_loc, device=device, data_type=kv_cache_dtype, - use_cuda_graph=use_captured_graph, - logits_soft_cap=logits_soft_cap) + use_cuda_graph=use_captured_graph) class FlashInferImpl(AttentionImpl): @@ -446,6 +440,7 @@ class FlashInferImpl(AttentionImpl): sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -458,6 +453,7 @@ class FlashInferImpl(AttentionImpl): raise ValueError("Sliding window is not supported in FlashInfer.") self.sliding_window = (-1, -1) self.kv_cache_dtype = kv_cache_dtype + self.logits_soft_cap = logits_soft_cap assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -532,7 +528,7 @@ class FlashInferImpl(AttentionImpl): output = 
prefill_meta.prefill_wrapper.forward( query, kv_cache, - logits_soft_cap=attn_metadata.logits_soft_cap, + logits_soft_cap=self.logits_soft_cap, causal=True) else: assert attn_metadata.decode_metadata is not None @@ -541,5 +537,5 @@ class FlashInferImpl(AttentionImpl): query, kv_cache, sm_scale=self.scale, - logits_soft_cap=attn_metadata.logits_soft_cap) + logits_soft_cap=self.logits_soft_cap) return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 4559dd15..bac30aec 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -105,9 +105,13 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "Torch SPDA does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "IPEX backend does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("IPEX backend does not support logits_soft_cap.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 2269ac26..4ecf698c 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -91,6 +91,7 @@ class PallasAttentionBackendImpl(AttentionImpl): sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -109,6 +110,9 @@ class PallasAttentionBackendImpl(AttentionImpl): raise NotImplementedError("FP8 KV cache dtype is not supported.") if blocksparse_params is not None: raise NotImplementedError("Blocksparse is not supported.") + if logits_soft_cap is not None: + raise NotImplementedError( + "Attention logits soft-capping is not supported.") if torch_xla.tpu.version() < 4: raise NotImplementedError("TPU version must be 4 or higher.") diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 058c8df0..26e9b8a9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -244,9 +244,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "ROCFlashAttention does not support blocksparse attention.") + if blocksparse_params is not None: + raise ValueError( + "ROCmFlashAttention does not support blocksparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "ROCmFlashAttention does not support attention logits soft " + "capping.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index fe6a5612..b83c673f 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -109,9 +109,13 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> 
None: - assert blocksparse_params is None, ValueError( - "Torch SPDA does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch SPDA does not support logits soft cap.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index dcd10ed4..bca13703 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -165,15 +165,6 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - logits_soft_cap = getattr(self.runner.model_config.hf_config, - "attn_logit_softcapping", None) - if logits_soft_cap is not None: - raise ValueError( - "Please use Flashinfer backend for models with logits_soft_cap " - "(i.e., Gemma-2). Otherwise, the output might be wrong. " - "Set Flashinfer backend by " - "export VLLM_ATTENTION_BACKEND=FLASHINFER.") - max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 1573cd7d..24ba5fc7 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -408,9 +408,14 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, ) -> None: - assert blocksparse_params is None, ValueError( - "XFormer does not support block-sparse attention.") + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "XFormers does not support attention logits soft capping.") self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 5fa552f2..2c21502d 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -34,6 +34,7 @@ class Attention(nn.Module): cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, prefix: str = "", ) -> None: super().__init__() @@ -82,7 +83,7 @@ class Attention(nn.Module): impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params) + blocksparse_params, logits_soft_cap) def forward( self, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index b77c901f..7bad2626 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -90,7 +90,8 @@ class Gemma2Attention(nn.Module): max_position_embeddings: int, rope_theta: float, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + attn_logits_soft_cap: Optional[float] = None) -> None: super().__init__() self.layer_idx = layer_idx self.config = config @@ -150,7 +151,8 @@ class Gemma2Attention(nn.Module): self.scaling, num_kv_heads=self.num_kv_heads, 
cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + logits_soft_cap=attn_logits_soft_cap) def forward( self, @@ -189,6 +191,7 @@ class Gemma2DecoderLayer(nn.Module): rope_theta=config.rope_theta, cache_config=cache_config, quant_config=quant_config, + attn_logits_soft_cap=config.attn_logit_softcapping, ) self.hidden_size = config.hidden_size self.mlp = Gemma2MLP(
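
[Note, not part of the patch] With logits_soft_cap plumbed through Attention and every AttentionImpl constructor, a model only has to forward the value from its HF config, and backends that cannot honor a non-None cap now raise instead of silently producing wrong output. A hypothetical wiring sketch that mirrors the Gemma-2 change above; apart from attn_logit_softcapping, the helper and config attribute names are placeholders, not part of this diff:

    from typing import Optional

    from vllm.attention import Attention
    from vllm.config import CacheConfig

    def build_attn(config, head_dim: int, scale: float,
                   cache_config: Optional[CacheConfig] = None) -> Attention:
        # None (the default) keeps soft capping disabled; the FlashAttention
        # backend maps it to softcap=0 for the kernel.
        soft_cap = getattr(config, "attn_logit_softcapping", None)
        return Attention(config.num_attention_heads,
                         head_dim,
                         scale,
                         num_kv_heads=config.num_key_value_heads,
                         cache_config=cache_config,
                         logits_soft_cap=soft_cap)

Keeping the cap on the constructor (rather than re-reading hf_config inside the metadata builders) is also what allows the FlashInfer-specific logits_soft_cap metadata field to be dropped above.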