[V1][TPU] Change kv cache shape. (#15145)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
parent
8310e0b59b
commit
b0e96aaebb
@ -17,9 +17,9 @@ ray[data]
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
|
@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> tuple[int, ...]:
|
||||
return (num_blocks, block_size, num_kv_heads, head_size)
|
||||
return (num_blocks, block_size, num_kv_heads * head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
@ -142,8 +142,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = ([num_blocks, block_size, num_kv_heads, head_size],
|
||||
[num_blocks, block_size, num_kv_heads, head_size])
|
||||
kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
|
||||
[num_blocks, block_size, num_kv_heads * head_size])
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
@ -157,8 +157,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens, hidden_size = query.shape
|
||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||
key = key.view(num_tokens, self.num_kv_heads, self.head_size)
|
||||
value = value.view(num_tokens, self.num_kv_heads, self.head_size)
|
||||
|
||||
key_cache, value_cache = kv_cache
|
||||
if kv_cache[0].numel() > 0:
|
||||
@ -192,10 +190,10 @@ def write_to_kv_cache(
|
||||
""" Write the key and values to the KV cache.
|
||||
|
||||
Args:
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
k_cache = [num_blocks, block_size, num_kv_heads, head_size]
|
||||
v_cache = [num_blocks, block_size, num_kv_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
k_cache = [num_blocks, block_size, num_kv_heads * head_size]
|
||||
v_cache = [num_blocks, block_size, num_kv_heads * head_size]
|
||||
|
||||
"""
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
|
||||
@ -203,6 +201,5 @@ def write_to_kv_cache(
|
||||
|
||||
key_cache = key_cache.flatten(0, 1)
|
||||
value_cache = value_cache.flatten(0, 1)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
key_cache.index_copy_(0, slot_mapping, key)
|
||||
value_cache.index_copy_(0, slot_mapping, value)
|
||||
|
Loading…
x
Reference in New Issue
Block a user