[Bugfix] Fix FP8 KV cache support (#4869)
This commit is contained in:
parent
2060e93659
commit
9a31a817a8
@ -200,15 +200,15 @@ class FlashAttentionImpl(AttentionImpl):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: int,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str = "auto",
|
kv_cache_dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
@ -164,15 +164,15 @@ class FlashInferImpl(AttentionImpl):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: int,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str = "auto",
|
kv_cache_dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
@ -197,15 +197,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: int,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str = "auto",
|
kv_cache_dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
@ -96,15 +96,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: int,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str = "auto",
|
kv_cache_dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
@ -208,15 +208,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
num_kv_heads: Optional[int] = None,
|
num_kv_heads: int,
|
||||||
alibi_slopes: Optional[List[float]] = None,
|
alibi_slopes: Optional[List[float]],
|
||||||
sliding_window: Optional[int] = None,
|
sliding_window: Optional[int],
|
||||||
kv_cache_dtype: str = "auto",
|
kv_cache_dtype: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.scale = float(scale)
|
self.scale = float(scale)
|
||||||
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
|
self.num_kv_heads = num_kv_heads
|
||||||
if alibi_slopes is not None:
|
if alibi_slopes is not None:
|
||||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
self.alibi_slopes = alibi_slopes
|
self.alibi_slopes = alibi_slopes
|
||||||
|
@ -48,7 +48,7 @@ class Attention(nn.Module):
|
|||||||
block_size)
|
block_size)
|
||||||
impl_cls = attn_backend.get_impl_cls()
|
impl_cls = attn_backend.get_impl_cls()
|
||||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||||
alibi_slopes, sliding_window)
|
alibi_slopes, sliding_window, kv_cache_dtype)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user