[V1][TPU] Integrate the new ragged paged attention kernel with vLLM v1 on TPU (#13379)

Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
iefgnoix 2025-02-28 10:01:36 -08:00 committed by GitHub
parent 4be4b26cb7
commit c3b6559a10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 335 additions and 887 deletions

View File

@ -17,9 +17,8 @@ ray[default]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241216%2Bcpu-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250124-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.7.0.dev20250226+cpu
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

View File

@ -4,13 +4,16 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
NUM_QUERIES_PER_BLOCK = 16
NUM_KV_PAGES_PER_BLOCK = 128
class PallasAttentionBackend(AttentionBackend):
@ -47,47 +50,23 @@ class PallasAttentionBackend(AttentionBackend):
) -> None:
raise RuntimeError("swap_blocks is not used for the TPU backend.")
@torch.compile(backend="openxla")
@staticmethod
def copy_blocks(
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
) -> None:
src_indices, dst_indices = src_to_dists
for k_cache, v_cache in kv_caches:
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
k_cache[:, dst_indices] = k_cache[:, src_indices]
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
v_cache[:, dst_indices] = v_cache[:, src_indices]
@dataclass
class PallasMetadata(AttentionMetadata):
class PallasMetadata:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables: Optional[torch.Tensor] = None
context_lens: Optional[torch.Tensor] = None
effective_query_lens: Optional[torch.Tensor] = None
@property
def prefill_metadata(self) -> Optional["PallasMetadata"]:
if self.num_prefills == 0:
return None
assert self.num_decode_tokens == 0
return self
@property
def decode_metadata(self) -> Optional["PallasMetadata"]:
if self.num_decode_tokens == 0:
return None
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.block_tables is not None
assert self.context_lens is not None
return self
# Used in the PallasAttentionBackendImpl
slot_mapping: torch.Tensor
block_tables: torch.Tensor
context_lens: torch.Tensor
query_start_loc: torch.Tensor
num_seqs: int
class PallasAttentionBackendImpl(AttentionImpl):
@ -105,10 +84,13 @@ class PallasAttentionBackendImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None,
attn_type: str = AttentionType.DECODER,
) -> None:
if blocksparse_params is not None:
raise ValueError("Paged attention Pallas kernel does "
"not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_heads = num_kv_heads
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@ -126,25 +108,6 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise NotImplementedError(
"Attention logits soft-capping is not supported.")
if torch_xla.tpu.version() < 4:
raise NotImplementedError("TPU version must be 4 or higher.")
self.megacore_mode = None
tpu_env = torch_xla.tpu.get_tpu_env()
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
or tpu_env.get("TYPE", None)
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
assert tpu_type is not None
tpu_type = tpu_type.lower()
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
if self.num_kv_heads % 2 == 0:
self.megacore_mode = "kv_head"
else:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self.megacore_mode = "batch"
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
@ -164,135 +127,47 @@ class PallasAttentionBackendImpl(AttentionImpl):
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
with shape [0] for profiling run.
query: shape = [num_tokens, num_heads * head_size]
key: shape = [num_tokens, num_kv_heads * head_size]
value: shape = [num_tokens, num_kv_heads * head_size]
kv_cache = ([num_kv_heads, num_blocks, block_size, head_size],
[num_kv_heads, num_blocks, block_size, head_size])
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
if attn_metadata is None:
# For determine_available_memory case.
if kv_cache[0].numel() == 0:
if output is None:
output = torch.ones_like(query)
return output
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
value = value.view(batch_size, seq_len, self.num_kv_heads,
self.head_size)
num_tokens, hidden_size = query.shape
query = query.view(num_tokens, self.num_heads, self.head_size)
key = key.view(num_tokens, self.num_kv_heads, self.head_size)
value = value.view(num_tokens, self.num_kv_heads, self.head_size)
key_cache, value_cache = kv_cache
if kv_cache[0].numel() > 0:
slot_mapping = attn_metadata.slot_mapping
key_cache, value_cache = kv_cache
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
query = query * self.scale
if attn_metadata.num_prefills > 0:
if attn_metadata.block_tables is None:
# Prefill without paged KV cache.
assert seq_len % 16 == 0, (
"Pallas FlashAttention kernel requires seq_len to be a "
f"multiple of 16 but got {seq_len}")
output = torch.ops.xla.ragged_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.query_start_loc,
attn_metadata.num_seqs,
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
use_kernel=False,
)
# Handle GQA/MQA.
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv,
dim=-2)
key = key.view(batch_size, seq_len, self.num_heads,
self.head_size)
value = value.repeat_interleave(self.num_queries_per_kv,
dim=-2)
value = value.view(batch_size, seq_len, self.num_heads,
self.head_size)
# FlashAttention kernel requires the input shape to be
# [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output = torch.ops.xla.flash_attention(
query.permute(0, 2, 1, 3),
key.permute(0, 2, 1, 3),
value.permute(0, 2, 1, 3),
True,
)
output = output.permute(0, 2, 1, 3)
else:
# Prefill with paged KV cache.
# TODO(woosuk): Tune the below knobs.
num_kv_pages_per_compute_block = 16
num_queries_per_compute_block = 16
assert seq_len % num_queries_per_compute_block == 0
output = torch.ops.xla.multi_queries_paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
attn_metadata.effective_query_lens,
num_kv_pages_per_compute_block,
num_queries_per_compute_block,
use_kernel=True,
)
else:
# Decoding run.
assert kv_cache[0].numel() > 0
query = query.squeeze(dim=1)
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
assert attn_metadata.block_tables is not None
assert attn_metadata.context_lens is not None
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
# block table in SMEM. Therefore, if the block table is too large,
# the kernel compilation will fail. To avoid this, we split the
# batch dimension into smaller chunks and run the kernel multiple
# times.
MAX_SMEM_USAGE = 512 * 1024
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
max_num_seq = MAX_SMEM_USAGE // size_per_seq
if batch_size <= max_num_seq:
output = paged_attention(
query,
key_cache,
value_cache,
attn_metadata.context_lens,
attn_metadata.block_tables,
pages_per_compute_block,
self.megacore_mode,
)
else:
chunk_size = max_num_seq
# Make sure the chunk size is a multiple of 2.
chunk_size = chunk_size // 2 * 2
num_chunks = (batch_size + chunk_size - 1) // chunk_size
output = torch.empty_like(query)
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * chunk_size
chunk_end = chunk_start + chunk_size
# NOTE(woosuk): We skip this line because it causes Dynamo
# compilation error. Instead, we rely on the slice operation
# to handle the out-of-bound case.
# chunk_end = min(chunk_end, batch_size)
chunk_output = paged_attention(
query[chunk_start:chunk_end],
key_cache,
value_cache,
attn_metadata.context_lens[chunk_start:chunk_end],
attn_metadata.block_tables[chunk_start:chunk_end],
pages_per_compute_block,
self.megacore_mode,
)
output[chunk_start:chunk_end] = chunk_output
# Reshape the output tensor.
return output.reshape(batch_size, seq_len, hidden_size)
return output.reshape(num_tokens, hidden_size)
def write_to_kv_cache(
@ -302,52 +177,21 @@ def write_to_kv_cache(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
) -> None:
""" Write the key and values to the KV cache.
Args:
key: shape = [num_tokens, num_kv_heads, head_size]
value: shape = [num_tokens, num_kv_heads, head_size]
k_cache = [num_kv_heads, num_blocks, block_size, head_size]
v_cache = [num_kv_heads, num_blocks, block_size, head_size]
"""
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
key = key.flatten(0, 2)
value = value.flatten(0, 2)
key = key.flatten(0, 1)
value = value.flatten(0, 1)
key_cache = key_cache.flatten(0, 2)
value_cache = value_cache.flatten(0, 2)
key_cache.index_copy_(0, slot_mapping, key)
value_cache.index_copy_(0, slot_mapping, value)
def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
context_lens: torch.Tensor,
block_tables: torch.Tensor,
pages_per_compute_block: int,
megacore_mode: Optional[str],
) -> torch.Tensor:
batch_size = query.shape[0]
if megacore_mode == "batch" and batch_size % 2 != 0:
megacore_mode = None
else:
megacore_mode = megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if megacore_mode is not None:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
megacore_mode=megacore_mode,
)
else:
output = torch.ops.xla.paged_attention(
query,
key_cache,
value_cache,
context_lens,
block_tables,
pages_per_compute_block,
)
return output

View File

@ -79,4 +79,4 @@ class ModelRunnerOutput:
# [prompt_len, num_prompt_logprobs]
# [prompt_len, num_prompt_logprobs]
# [prompt_len]
prompt_logprobs_dict: Dict[str, LogprobsTensors]
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]]

View File

@ -1071,12 +1071,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self,
hidden_states: torch.Tensor,
scheduler_output: "SchedulerOutput",
) -> Dict[str, LogprobsTensors]:
) -> Dict[str, Optional[LogprobsTensors]]:
num_prompt_logprobs_dict = self.input_batch.num_prompt_logprobs
if not num_prompt_logprobs_dict:
return {}
prompt_logprobs_dict: Dict[str, LogprobsTensors] = {}
prompt_logprobs_dict: Dict[str, Optional[LogprobsTensors]] = {}
# Since prompt logprobs are a rare feature, prioritize simple,
# maintainable loop over optimal performance.

File diff suppressed because it is too large Load Diff

View File

@ -21,7 +21,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
logger = init_logger(__name__)
@ -126,9 +126,7 @@ class TPUWorker:
self.model_runner.dummy_run(
runner_kv_caches,
num_tokens=1,
seq_len=self.scheduler_config.max_num_batched_tokens,
exec_mode=ExecutionMode.PREFILL,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
# Synchronize before measuring the memory usage.