[V1][TPU] Apply the ragged paged attention kernel fix and remove the padding. (#14846)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
parent 69698f257e
commit b4ad56c1bd
@@ -17,9 +17,9 @@ ray[data]
 --find-links https://storage.googleapis.com/libtpu-releases/index.html
 --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
 --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
-torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250306%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
+torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250314%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
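The dependency bump above is what pulls in the fixed ragged paged attention kernel: each torch and torch_xla nightly moves from the 2025-03-06 build to the 2025-03-14 build, with a PEP 508 environment marker selecting the wheel that matches the running interpreter. As a minimal illustration of how such a marker is evaluated (the `packaging` library call below is standard; the snippet itself is a sketch, not part of the commit):

    # Sketch: evaluate the environment marker that gates the cp310 wheel.
    from packaging.markers import Marker

    marker = Marker('python_version == "3.10"')
    # True only under a Python 3.10 interpreter; pip performs the same
    # check to decide which of the pinned wheel URLs above applies.
    print(marker.evaluate())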
@@ -23,8 +23,7 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
-from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
-                                               PallasAttentionBackend,
+from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
                                                PallasMetadata)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
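NUM_KV_PAGES_PER_BLOCK drops out of this import because the model runner below no longer pads with it. It is a tiling parameter of the Pallas ragged paged attention kernel: roughly, how many KV-cache pages one kernel block consumes per grid step. A hedged sketch of the arithmetic such a constant drives (the values are placeholders, and `cdiv` here mirrors the `vllm.utils.cdiv` imported above):

    # Sketch with placeholder numbers, not vLLM's real constants.
    def cdiv(a: int, b: int) -> int:
        # Ceiling division, equivalent in effect to vllm.utils.cdiv.
        return -(a // -b)

    NUM_KV_PAGES_PER_BLOCK = 16   # hypothetical tile size (pages per kernel block)
    num_pages_for_request = 100   # hypothetical pages held by one request

    # Grid steps needed for the kernel to cover all of the request's pages.
    assert cdiv(num_pages_for_request, NUM_KV_PAGES_PER_BLOCK) == 7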
@@ -139,10 +138,8 @@ class TPUModelRunner:
                                             device="cpu")
         self.slot_mapping_np = self.slot_mapping_cpu.numpy()
 
-        padded_max_num_blocks_per_req = _get_padded_number(
-            self.max_num_blocks_per_req, NUM_KV_PAGES_PER_BLOCK)
         self.block_table_cpu = torch.zeros(
-            (self.max_num_tokens, padded_max_num_blocks_per_req),
+            (self.max_num_tokens, self.max_num_blocks_per_req),
             dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
             device="cpu")
 
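This hunk is the padding removal named in the commit title: with the kernel fix in place, the CPU-side block table can be allocated at its natural width (`self.max_num_blocks_per_req`) rather than a width rounded up to a multiple of NUM_KV_PAGES_PER_BLOCK. For reference, a minimal sketch of the round-up the deleted `_get_padded_number` call performed (the body shown is the standard round-up-to-multiple idiom, reproduced here for illustration):

    def _get_padded_number(x: int, multiple: int) -> int:
        # Round x up to the nearest multiple of `multiple`.
        return ((x + multiple - 1) // multiple) * multiple

    # With hypothetical numbers: given 100 pages and a 16-page tile, the
    # old code allocated 112 block-table columns per request; the new
    # code allocates exactly 100.
    assert _get_padded_number(100, 16) == 112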