[Kernel][TPU][ragged-paged-attn] vLLM code change for PR#8896 (#15659)
Signed-off-by: Yarong Mu <ymu@google.com>
This commit is contained in:
parent
da461f3cbf
commit
7c1f760024
@ -17,9 +17,9 @@ ray[data]
|
|||||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250328-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250328-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250319-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250328-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250328-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250328-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250319-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250328-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||||
|
@ -41,7 +41,7 @@ class PallasAttentionBackend(AttentionBackend):
|
|||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
) -> tuple[int, ...]:
|
) -> tuple[int, ...]:
|
||||||
return (num_blocks, block_size, num_kv_heads * head_size)
|
return (num_blocks, block_size, num_kv_heads * 2, head_size)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def swap_blocks(
|
def swap_blocks(
|
||||||
@ -132,7 +132,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
kv_cache: tuple[torch.Tensor, torch.Tensor],
|
kv_cache: torch.Tensor,
|
||||||
attn_metadata: PallasMetadata,
|
attn_metadata: PallasMetadata,
|
||||||
output: Optional[torch.Tensor] = None,
|
output: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
@ -142,14 +142,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
query: shape = [num_tokens, num_heads * head_size]
|
query: shape = [num_tokens, num_heads * head_size]
|
||||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
kv_cache = ([num_blocks, block_size, num_kv_heads * head_size],
|
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||||
[num_blocks, block_size, num_kv_heads * head_size])
|
|
||||||
attn_metadata: Metadata for attention.
|
attn_metadata: Metadata for attention.
|
||||||
Returns:
|
Returns:
|
||||||
shape = [num_tokens, num_heads * head_size]
|
shape = [num_tokens, num_heads * head_size]
|
||||||
"""
|
"""
|
||||||
# For determine_available_memory case.
|
# For determine_available_memory case.
|
||||||
if kv_cache[0].numel() == 0:
|
if kv_cache.numel() == 0:
|
||||||
if output is None:
|
if output is None:
|
||||||
output = torch.ones_like(query)
|
output = torch.ones_like(query)
|
||||||
return output
|
return output
|
||||||
@ -158,15 +157,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
num_tokens, hidden_size = query.shape
|
num_tokens, hidden_size = query.shape
|
||||||
query = query.view(num_tokens, self.num_heads, self.head_size)
|
query = query.view(num_tokens, self.num_heads, self.head_size)
|
||||||
|
|
||||||
key_cache, value_cache = kv_cache
|
if kv_cache.numel() > 0:
|
||||||
if kv_cache[0].numel() > 0:
|
|
||||||
slot_mapping = attn_metadata.slot_mapping
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
|
write_to_kv_cache(key, value, kv_cache, slot_mapping)
|
||||||
|
|
||||||
output = torch.ops.xla.ragged_paged_attention(
|
output = torch.ops.xla.ragged_paged_attention(
|
||||||
query,
|
query,
|
||||||
key_cache,
|
kv_cache,
|
||||||
value_cache,
|
|
||||||
attn_metadata.context_lens,
|
attn_metadata.context_lens,
|
||||||
attn_metadata.block_tables,
|
attn_metadata.block_tables,
|
||||||
attn_metadata.query_start_loc,
|
attn_metadata.query_start_loc,
|
||||||
@ -183,23 +180,27 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
def write_to_kv_cache(
|
def write_to_kv_cache(
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
key_cache: torch.Tensor,
|
kv_cache: torch.Tensor,
|
||||||
value_cache: torch.Tensor,
|
|
||||||
slot_mapping: torch.Tensor,
|
slot_mapping: torch.Tensor,
|
||||||
) -> None:
|
) -> None:
|
||||||
""" Write the key and values to the KV cache.
|
""" Write the key and values to the KV cache.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
k_cache = [num_blocks, block_size, num_kv_heads * head_size]
|
kv_cache = [num_blocks, block_size, num_kv_heads * 2, head_size]
|
||||||
v_cache = [num_blocks, block_size, num_kv_heads * head_size]
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
|
_, _, num_combined_kv_heads, head_size = kv_cache.shape
|
||||||
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
|
num_kv_heads = num_combined_kv_heads // 2
|
||||||
|
|
||||||
key_cache = key_cache.flatten(0, 1)
|
key = key.view(-1, num_kv_heads, head_size)
|
||||||
value_cache = value_cache.flatten(0, 1)
|
value = value.view(-1, num_kv_heads, head_size)
|
||||||
key_cache.index_copy_(0, slot_mapping, key)
|
|
||||||
value_cache.index_copy_(0, slot_mapping, value)
|
kv = torch.cat([key, value], axis=-1).reshape(-1, num_combined_kv_heads,
|
||||||
|
head_size)
|
||||||
|
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)
|
||||||
|
|
||||||
|
kv_cache = kv_cache.flatten(0, 1)
|
||||||
|
kv_cache.index_copy_(0, slot_mapping, kv)
|
||||||
|
@ -861,12 +861,11 @@ class TPUModelRunner:
|
|||||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||||
dtype = kv_cache_spec.dtype
|
dtype = kv_cache_spec.dtype
|
||||||
|
|
||||||
tpu_k_cache = torch.zeros(kv_cache_shape,
|
tpu_kv_cache = torch.zeros(kv_cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
|
||||||
|
|
||||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
kv_caches[layer_name] = tpu_kv_cache
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -893,7 +892,7 @@ class ModelWrapperV1(nn.Module):
|
|||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: torch.Tensor,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
|
kv_caches: list[torch.Tensor],
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Executes the forward pass of the model.
|
"""Executes the forward pass of the model.
|
||||||
|
@ -136,10 +136,10 @@ class TPUWorker:
|
|||||||
|
|
||||||
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
# Use an empty tensor instead of `None`` to force Dynamo to pass
|
||||||
# it by reference, rather by specializing on the value ``None``.
|
# it by reference, rather by specializing on the value ``None``.
|
||||||
tpu_k_cache = torch.tensor([], dtype=dtype, device=self.device)
|
tpu_kv_cache = torch.tensor([],
|
||||||
tpu_v_cache = torch.tensor([], dtype=dtype, device=self.device)
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
kv_caches[layer_name] = (tpu_k_cache, tpu_v_cache)
|
kv_caches[layer_name] = tpu_kv_cache
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user