[Kernel] Pipe attn_logits_soft_cap through paged attention TPU kernels (#12482)

Signed-off-by: Fenghui Zhang <fhzhang@google.com>
Author:    fenghuizhang
Date:      2025-01-28 14:36:44 -08:00
Committer: GitHub
Parent:    c386c43ca3
Commit:    80fcc3ed1c
2 changed files with 16 additions and 26 deletions

.buildkite/run-tpu-test.sh | 0  (mode changed: Normal file → Executable file)
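
For context: attention logits soft-capping squashes the pre-softmax attention scores through a scaled tanh so that no logit exceeds the cap in magnitude (Gemma-2-style models rely on this). A minimal PyTorch sketch of the transform itself, independent of any TPU kernel:

from typing import Optional

import torch


def soft_cap_logits(scores: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    """Map raw attention scores into (-cap, cap); cap=None means no capping."""
    if cap is None:
        return scores
    return cap * torch.tanh(scores / cap)

Before this commit the Pallas backend rejected a non-None cap with NotImplementedError; the diff below stores the value and forwards it to every paged attention kernel call.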

@@ -110,6 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
@@ -120,9 +121,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")
         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
@@ -230,6 +228,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
                     num_kv_pages_per_compute_block,
                     num_queries_per_compute_block,
                     use_kernel=True,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
         else:
             # Decoding run.
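
The prefill path above forwards the stored cap to the kernel as a keyword argument. As a rough illustration of what the kernel is asked to compute, here is a plain-PyTorch causal attention with an optional cap; the [seq, heads, head_size] layout and the cap-before-mask ordering are assumptions of this sketch, not a statement about the Pallas kernel internals:

import torch


def causal_attention_reference(q, k, v, scale, attn_logits_soft_cap=None):
    # q, k, v: [seq_len, num_heads, head_size]
    scores = torch.einsum("qhd,khd->hqk", q * scale, k)
    if attn_logits_soft_cap is not None:
        scores = attn_logits_soft_cap * torch.tanh(scores / attn_logits_soft_cap)
    seq_len = q.shape[0]
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    probs = torch.softmax(scores + mask, dim=-1)
    return torch.einsum("hqk,khd->qhd", probs, v)
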
@@ -257,6 +256,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
                     attn_metadata.block_tables,
                     pages_per_compute_block,
                     self.megacore_mode,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
             else:
                 chunk_size = max_num_seq
@@ -280,6 +280,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
                         attn_metadata.block_tables[chunk_start:chunk_end],
                         pages_per_compute_block,
                         self.megacore_mode,
+                        attn_logits_soft_cap=self.logits_soft_cap,
                     )
                     output[chunk_start:chunk_end] = chunk_output
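
The two decode hunks above add the same keyword argument to both call sites: the single-launch path and the chunked path that runs at most max_num_seq sequences per kernel launch. A hedged sketch of that chunked-dispatch pattern, with illustrative names (chunked_dispatch and run_chunk are not vLLM functions):

import torch


def chunked_dispatch(query: torch.Tensor, chunk_size: int, run_chunk):
    # Process the batch in slices of at most chunk_size sequences and write
    # each partial result back into a preallocated output tensor.
    batch_size = query.shape[0]
    output = torch.empty_like(query)
    for chunk_start in range(0, batch_size, chunk_size):
        chunk_end = min(chunk_start + chunk_size, batch_size)
        output[chunk_start:chunk_end] = run_chunk(query[chunk_start:chunk_end])
    return output
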
@@ -313,6 +314,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     pages_per_compute_block: int,
     megacore_mode: Optional[str],
+    *,
+    attn_logits_soft_cap: Optional[float],
 ) -> torch.Tensor:
     batch_size = query.shape[0]
     if megacore_mode == "batch" and batch_size % 2 != 0:
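
The bare * in the updated signature makes attn_logits_soft_cap keyword-only, so existing positional call sites cannot silently pass a value into the wrong slot. A self-contained illustration of that Python behavior (paged_attention_stub is a hypothetical stand-in, not the vLLM function):

from typing import Optional


def paged_attention_stub(pages_per_compute_block: int, megacore_mode: Optional[str], *,
                         attn_logits_soft_cap: Optional[float]) -> None:
    pass


paged_attention_stub(16, None, attn_logits_soft_cap=30.0)  # OK: passed by keyword
# paged_attention_stub(16, None, 30.0)  # TypeError: takes 2 positional arguments but 3 were given
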
@@ -320,11 +323,7 @@ def paged_attention(
     else:
         megacore_mode = megacore_mode
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
+    return torch.ops.xla.paged_attention(
+        query,
+        key_cache,
+        value_cache,
@@ -332,14 +331,5 @@ def paged_attention(
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
+        block_tables,
+        pages_per_compute_block,
+        megacore_mode=megacore_mode,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+    )
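
For readers unfamiliar with the kernel's contract, below is a slow, pure-PyTorch reference for one decode step of paged attention with the optional cap applied before softmax. The cache layout [num_blocks, block_size, num_heads, head_size], the assumption that num_heads == num_kv_heads, and the in-function scaling are simplifications of this sketch; they are not taken from the vLLM or torch_xla implementation.

from typing import Optional

import torch


def paged_attention_reference(
    query: torch.Tensor,         # [batch, num_heads, head_size]
    key_cache: torch.Tensor,     # [num_blocks, block_size, num_heads, head_size]
    value_cache: torch.Tensor,   # same layout as key_cache
    context_lens: torch.Tensor,  # [batch]
    block_tables: torch.Tensor,  # [batch, max_blocks_per_seq]
    *,
    attn_logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
    _, num_heads, head_size = query.shape
    block_size = key_cache.shape[1]
    scale = head_size ** -0.5
    outputs = []
    for i in range(query.shape[0]):
        ctx_len = int(context_lens[i])
        num_blocks = (ctx_len + block_size - 1) // block_size
        blocks = block_tables[i, :num_blocks].long()
        # Gather this sequence's KV pages from the paged cache and trim padding tokens.
        k = key_cache[blocks].reshape(-1, num_heads, head_size)[:ctx_len]
        v = value_cache[blocks].reshape(-1, num_heads, head_size)[:ctx_len]
        scores = torch.einsum("hd,thd->ht", query[i] * scale, k)
        if attn_logits_soft_cap is not None:
            scores = attn_logits_soft_cap * torch.tanh(scores / attn_logits_soft_cap)
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.einsum("ht,thd->hd", probs, v))
    return torch.stack(outputs)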