diff --git a/requirements/tpu.txt b/requirements/tpu.txt
index 97a39bcd..7246fc19 100644
--- a/requirements/tpu.txt
+++ b/requirements/tpu.txt
@@ -17,9 +17,9 @@ ray[data]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index effcac7e..00869467 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -23,8 +23,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -139,10 +138,8 @@ class TPUModelRunner:
                                              device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()

-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
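
For context, a minimal sketch of the padding this change removes. The body of _get_padded_number and the example value of NUM_KV_PAGES_PER_BLOCK are assumptions for illustration, not taken from this diff; the helper is assumed to round a count up to the nearest multiple.

def _get_padded_number(x: int, multiple: int) -> int:
    # Assumed behavior: round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

# Before this change, the block table's second dimension was padded up to
# a multiple of NUM_KV_PAGES_PER_BLOCK. With a hypothetical value of 128
# and max_num_blocks_per_req = 200, the padded width would be:
assert _get_padded_number(200, 128) == 256

# After this change, block_table_cpu is allocated with width
# max_num_blocks_per_req directly (200 in this example), presumably
# because the Pallas attention backend no longer needs the padding.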