[TPU] Update PyTorch/XLA (#16288)
Signed-off-by: Chengji Yao <chengjiyao@google.com>
This commit is contained in:
parent
87b4ac56c2
commit
b1eb4ca152
@ -17,10 +17,10 @@ ray[data]
|
||||
--find-links https://storage.googleapis.com/libtpu-releases/index.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
|
||||
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250406-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
|
||||
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
|
||||
|
||||
|
@ -4,9 +4,7 @@ from unittest.mock import ANY, patch
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
|
||||
NUM_QUERIES_PER_BLOCK,
|
||||
PallasAttentionBackendImpl,
|
||||
from vllm.v1.attention.backends.pallas import (PallasAttentionBackendImpl,
|
||||
PallasMetadata)
|
||||
|
||||
|
||||
@ -32,8 +30,6 @@ def test_ragged_paged_attention():
|
||||
logits_soft_cap=logits_soft_cap,
|
||||
attn_type=AttentionType.DECODER,
|
||||
)
|
||||
mock_vmem_limit_bytes = 1024
|
||||
attn_impl.vmem_limit_bytes = mock_vmem_limit_bytes
|
||||
|
||||
class FakeAttentionLayer:
|
||||
_k_scale_float: float
|
||||
@ -88,9 +84,9 @@ def test_ragged_paged_attention():
|
||||
ANY, # block_tables
|
||||
ANY, # query_start_loc
|
||||
ANY, # num_seqs
|
||||
num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
|
||||
num_queries_per_block=NUM_QUERIES_PER_BLOCK,
|
||||
vmem_limit_bytes=mock_vmem_limit_bytes,
|
||||
num_kv_pages_per_block=None,
|
||||
num_queries_per_block=None,
|
||||
vmem_limit_bytes=None,
|
||||
use_kernel=True,
|
||||
sm_scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
|
Loading…
x
Reference in New Issue
Block a user