[TPU] Update PyTorch/XLA (#16288)

Signed-off-by: Chengji Yao <chengjiyao@google.com>
Chengji Yao 2025-04-08 23:46:32 -07:00 committed by GitHub
parent 87b4ac56c2
commit b1eb4ca152
2 changed files with 10 additions and 14 deletions


@@ -17,10 +17,10 @@ ray[data]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
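Note: this hunk in the TPU requirements file only bumps the pinned PyTorch/XLA nightly wheels from 2.8.0.dev20250406 to 2.8.0.dev20250408 for Python 3.9-3.11. A quick post-install sanity check, as a sketch: it assumes the wheels above installed cleanly and that both packages expose __version__ with the dev date in it.

# Sketch only: verify torch and torch_xla resolved to the same nightly build.
import torch
import torch_xla

# startswith() hedges against any local-version suffix on the wheel version.
assert torch.__version__.startswith("2.8.0.dev20250408"), torch.__version__
assert torch_xla.__version__.startswith("2.8.0.dev20250408"), torch_xla.__version__
print("torch:", torch.__version__, "| torch_xla:", torch_xla.__version__)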


@@ -4,9 +4,7 @@ from unittest.mock import ANY, patch
 import torch

 from vllm.attention.backends.abstract import AttentionType
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               NUM_QUERIES_PER_BLOCK,
-                                               PallasAttentionBackendImpl,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl,
                                                PallasMetadata)
@@ -32,8 +30,6 @@ def test_ragged_paged_attention():
         logits_soft_cap=logits_soft_cap,
         attn_type=AttentionType.DECODER,
     )
-    mock_vmem_limit_bytes = 1024
-    attn_impl.vmem_limit_bytes = mock_vmem_limit_bytes

     class FakeAttentionLayer:
         _k_scale_float: float
@@ -88,9 +84,9 @@ def test_ragged_paged_attention():
         ANY,  # block_tables
         ANY,  # query_start_loc
         ANY,  # num_seqs
-        num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
-        num_queries_per_block=NUM_QUERIES_PER_BLOCK,
-        vmem_limit_bytes=mock_vmem_limit_bytes,
+        num_kv_pages_per_block=None,
+        num_queries_per_block=None,
+        vmem_limit_bytes=None,
         use_kernel=True,
         sm_scale=scale,
         sliding_window=sliding_window,
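Note: the test previously pinned vLLM's module-level constants NUM_KV_PAGES_PER_BLOCK / NUM_QUERIES_PER_BLOCK and a mocked vmem_limit_bytes; with the updated nightlies it asserts None for all three, deferring the tuning choices to the kernel side. A minimal sketch of that defaulting pattern follows; the function name and default values are hypothetical stand-ins, not the actual torch_xla ragged_paged_attention API.

# Hypothetical stand-in for the kernel entry point. The point being tested
# above: None means "let the kernel choose", so vLLM no longer hard-codes
# these knobs on the caller side.
def ragged_paged_attention_stub(*args,
                                num_kv_pages_per_block=None,
                                num_queries_per_block=None,
                                vmem_limit_bytes=None):
    if num_kv_pages_per_block is None:
        num_kv_pages_per_block = 128          # assumed kernel heuristic
    if num_queries_per_block is None:
        num_queries_per_block = 32            # assumed kernel heuristic
    if vmem_limit_bytes is None:
        vmem_limit_bytes = 64 * 1024 * 1024   # assumed kernel VMEM budget
    return num_kv_pages_per_block, num_queries_per_block, vmem_limit_bytes


# The backend now passes None straight through, mirroring the assertion above.
assert ragged_paged_attention_stub() == (128, 32, 64 * 1024 * 1024)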