[Kernel] Pipe attn_logits_soft_cap through paged attention TPU kernels (#12482)
Signed-off-by: Fenghui Zhang <fhzhang@google.com>
This commit is contained in:
parent c386c43ca3
commit 80fcc3ed1c

.buildkite/run-tpu-test.sh (0 lines changed, Normal file → Executable file)
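Background on the feature being piped through: attention logits soft-capping squashes the raw query-key scores with a tanh so their magnitude stays below a fixed bound before the softmax (the scheme used by Gemma-2-style models, which this backend previously rejected). A minimal sketch of the transform, assuming the standard tanh formulation; cap_attention_logits is an illustrative helper, not code from this commit:

from typing import Optional

import torch


def cap_attention_logits(logits: torch.Tensor,
                         soft_cap: Optional[float]) -> torch.Tensor:
    # Soft-capping: bound logits to the open interval (-soft_cap, soft_cap).
    if soft_cap is None:
        return logits
    return soft_cap * torch.tanh(logits / soft_cap)


# Extreme scores are squashed to just under the cap; small ones are nearly unchanged.
scores = torch.tensor([-100.0, 0.5, 100.0])
print(cap_attention_logits(scores, soft_cap=30.0))  # approx. [-29.92, 0.50, 29.92]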
@@ -110,6 +110,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.logits_soft_cap = logits_soft_cap

         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
@@ -120,9 +121,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")

         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
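With this guard removed, the TPU backend no longer rejects models whose config sets a logits soft cap; the value is stored on the impl and forwarded to the kernel calls below instead. Numerically, what the kernel is asked to do corresponds to capping the scaled scores before normalizing, as in this hedged pure-PyTorch reference (an illustration of the math, not the Pallas kernel):

from typing import Optional

import torch


def ref_attention(q: torch.Tensor,          # [num_tokens, num_heads, head_size]
                  k: torch.Tensor,          # [num_tokens, num_heads, head_size]
                  v: torch.Tensor,          # [num_tokens, num_heads, head_size]
                  scale: float,
                  soft_cap: Optional[float] = None) -> torch.Tensor:
    # Scaled dot-product scores per head: [num_heads, num_tokens, num_tokens].
    scores = torch.einsum("qhd,khd->hqk", q, k) * scale
    if soft_cap is not None:
        # Bound the scores to (-soft_cap, soft_cap) before the softmax.
        scores = soft_cap * torch.tanh(scores / soft_cap)
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum("hqk,khd->qhd", weights, v)


q = k = v = torch.randn(5, 8, 128)
out = ref_attention(q, k, v, scale=128 ** -0.5, soft_cap=50.0)
print(out.shape)  # torch.Size([5, 8, 128])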
@@ -230,6 +228,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
                     num_kv_pages_per_compute_block,
                     num_queries_per_compute_block,
                     use_kernel=True,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
         else:
             # Decoding run.
@@ -257,6 +256,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
                     attn_metadata.block_tables,
                     pages_per_compute_block,
                     self.megacore_mode,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
             else:
                 chunk_size = max_num_seq
@@ -280,6 +280,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
                         attn_metadata.block_tables[chunk_start:chunk_end],
                         pages_per_compute_block,
                         self.megacore_mode,
+                        attn_logits_soft_cap=self.logits_soft_cap,
                     )
                     output[chunk_start:chunk_end] = chunk_output

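This is the third call site: the chunked decode path splits an oversized batch into slices of at most max_num_seq sequences, runs the kernel per slice, and writes each result into the matching slice of the output, so the new keyword rides along on every chunk. A generic sketch of that chunking pattern, with process_chunk standing in for the kernel call:

import torch


def chunked_decode(query: torch.Tensor, max_num_seq: int) -> torch.Tensor:
    """Process a batch in slices of at most max_num_seq sequences."""

    def process_chunk(chunk: torch.Tensor) -> torch.Tensor:
        # Stand-in for the per-chunk paged-attention kernel call.
        return chunk * 2.0

    batch_size = query.shape[0]
    output = torch.empty_like(query)
    for chunk_start in range(0, batch_size, max_num_seq):
        chunk_end = min(chunk_start + max_num_seq, batch_size)
        output[chunk_start:chunk_end] = process_chunk(query[chunk_start:chunk_end])
    return output


print(chunked_decode(torch.ones(10, 4), max_num_seq=4))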
@@ -313,6 +314,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     pages_per_compute_block: int,
     megacore_mode: Optional[str],
+    *,
+    attn_logits_soft_cap: Optional[float],
 ) -> torch.Tensor:
     batch_size = query.shape[0]
     if megacore_mode == "batch" and batch_size % 2 != 0:
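The bare * makes attn_logits_soft_cap keyword-only, so existing positional call sites cannot silently shift an argument into the new slot; callers have to spell the keyword out. A small self-contained illustration of that calling convention (f is a toy function, not the vLLM wrapper):

from typing import Optional


def f(pages_per_compute_block: int,
      megacore_mode: Optional[str],
      *,
      attn_logits_soft_cap: Optional[float]) -> str:
    return f"{pages_per_compute_block=} {megacore_mode=} {attn_logits_soft_cap=}"


print(f(4, None, attn_logits_soft_cap=30.0))   # OK: keyword form required
try:
    f(4, None, 30.0)                           # positional form is rejected
except TypeError as e:
    print("TypeError:", e)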
@@ -320,11 +323,7 @@ def paged_attention(
     else:
         megacore_mode = megacore_mode

-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
+    return torch.ops.xla.paged_attention(
         query,
         key_cache,
         value_cache,
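The deleted NOTE documented a workaround for an op binding that rejected None for megacore_mode; collapsing both branches into one call implies the binding now tolerates an optional value, so the wrapper can pass it straight through. A generic before/after sketch of that simplification, with op_with_optional_mode standing in for the XLA op:

from typing import Optional


def op_with_optional_mode(x: int, mode: Optional[str] = None) -> str:
    # Stand-in for a kernel binding that tolerates mode=None.
    return f"x={x}, mode={mode}"


# Before: branch so that None is never passed to the binding.
def call_with_workaround(x: int, mode: Optional[str]) -> str:
    if mode is not None:
        return op_with_optional_mode(x, mode=mode)
    return op_with_optional_mode(x)


# After: pass the optional value straight through.
def call_direct(x: int, mode: Optional[str]) -> str:
    return op_with_optional_mode(x, mode=mode)


assert call_with_workaround(1, None) == call_direct(1, None)
assert call_with_workaround(1, "batch") == call_direct(1, "batch")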
@@ -332,14 +331,5 @@ def paged_attention(
         block_tables,
         pages_per_compute_block,
         megacore_mode=megacore_mode,
+        attn_logits_soft_cap=attn_logits_soft_cap,
     )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
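Taken together, callers of the wrapper now supply the cap as a keyword argument. The sketch below only illustrates the call shape: tensor sizes are made up, the import path is an assumption about where this wrapper lives, and the call is guarded so nothing executes outside a torch_xla environment:

import importlib.util

import torch


def demo_paged_attention_call(paged_attention_fn):
    # Illustrative shapes only: 2 decode sequences, 8 heads, head_size 128,
    # a KV cache of 16 blocks x 16 tokens, and up to 16 blocks per sequence.
    query = torch.zeros(2, 8, 128)
    key_cache = torch.zeros(8, 16 * 16, 128)
    value_cache = torch.zeros(8, 16 * 16, 128)
    context_lens = torch.tensor([5, 7], dtype=torch.int32)
    block_tables = torch.zeros(2, 16, dtype=torch.int32)
    return paged_attention_fn(
        query,
        key_cache,
        value_cache,
        context_lens,
        block_tables,
        pages_per_compute_block=4,
        megacore_mode=None,
        attn_logits_soft_cap=30.0,  # the new keyword-only argument
    )


if importlib.util.find_spec("torch_xla") is not None:
    # Assumed import path for the wrapper shown in this diff.
    from vllm.attention.backends.pallas import paged_attention
    print(demo_paged_attention_call(paged_attention).shape)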