[TPU] Support sliding window and logit soft capping in the paged attention kernel for TPU. (#15732)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
This commit is contained in:
parent 03a70eacaf
commit b6be6f8d1e
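For context: sliding-window attention limits each query token to the most recent `sliding_window` key positions, and logit soft capping squashes attention scores into (-cap, cap) with a scaled tanh (the Gemma-style cap * tanh(score / cap)). The snippet below is only a minimal PyTorch sketch of that semantics to aid review; the function and tensor names (`reference_mask_and_cap`, `q_pos`, `k_pos`) are illustrative, and the actual Pallas kernel applies these steps on-device (its exact masking/capping order is not shown in this diff).

from typing import Optional

import torch


def reference_mask_and_cap(scores: torch.Tensor, q_pos: torch.Tensor,
                           k_pos: torch.Tensor,
                           sliding_window: Optional[int],
                           soft_cap: Optional[float]) -> torch.Tensor:
    # Soft capping: bound raw attention scores to (-soft_cap, soft_cap).
    if soft_cap is not None:
        scores = soft_cap * torch.tanh(scores / soft_cap)
    # Causal mask: a query at position i may only attend to keys at
    # positions <= i.
    mask = k_pos[None, :] <= q_pos[:, None]
    # Sliding window: additionally drop keys more than sliding_window - 1
    # positions behind the query.
    if sliding_window is not None:
        mask &= k_pos[None, :] > (q_pos[:, None] - sliding_window)
    # Masked-out positions get -inf so they vanish after softmax.
    return scores.masked_fill(~mask, float("-inf"))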
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-set -e
+set -xue
 
 # Build the docker image.
 docker build -f docker/Dockerfile.tpu -t vllm-tpu .
@@ -38,7 +38,9 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_7 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \
     && echo TEST_8 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py" \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \
+    && echo TEST_9 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling
@@ -13,18 +13,24 @@ import pytest
 
 from vllm.platforms import current_platform
 
-MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
+MODEL_NAMES = [
+    "Qwen/Qwen2-1.5B-Instruct",
+    "google/gemma-3-1b-it",
+]
 NUM_CONCURRENT = 500
 TASK = "gsm8k"
 FILTER = "exact_match,strict-match"
 RTOL = 0.03
-EXPECTED_VALUE = 0.58
+EXPECTED_VALUES = {
+    "Qwen/Qwen2-1.5B-Instruct": 0.58,
+    "google/gemma-3-1b-it": 0.25,
+}
 
 
-def run_test(more_args=None):
+def run_test(model_name, more_args=None):
     """Run the end to end accuracy test."""
 
-    model_args = f"pretrained={MODEL_NAME},max_model_len=4096"
+    model_args = f"pretrained={model_name},max_model_len=4096"
 
     if more_args is not None:
         model_args = "{},{}".format(model_args, more_args)
@@ -37,9 +43,12 @@ def run_test(more_args=None):
     )
 
     measured_value = results["results"][TASK][FILTER]
-    assert (measured_value - RTOL < EXPECTED_VALUE
-            and measured_value + RTOL > EXPECTED_VALUE
-            ), f"Expected: {EXPECTED_VALUE} | Measured: {measured_value}"
+    assert model_name in EXPECTED_VALUES, (
+        f"Cannot find the expected value for the model {model_name=}")
+    expected_value = EXPECTED_VALUES[model_name]
+    assert (measured_value - RTOL < expected_value
+            and measured_value + RTOL > expected_value
+            ), f"Expected: {expected_value} | Measured: {measured_value}"
 
 
 # TODO: [AlexM] Fix it with new CI/CD tests
@@ -49,7 +58,8 @@ TPU_TP_TEST_STR = "" #"tensor_parallel_size=4"
 @pytest.mark.skipif(not current_platform.is_cuda()
                     and not current_platform.is_tpu(),
                     reason="V1 is currently only supported on CUDA and TPU")
-def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
+@pytest.mark.parametrize("model", MODEL_NAMES)
+def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
     """Run with the V1 Engine."""
 
     with monkeypatch.context() as m:
@@ -64,7 +74,7 @@ def test_lm_eval_accuracy_v1_engine(monkeypatch: pytest.MonkeyPatch):
         if TPU_TP_TEST_STR:
             more_args += ",{}".format(TPU_TP_TEST_STR)
 
-        run_test(more_args)
+        run_test(model, more_args)
 
 
 def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
@@ -72,4 +82,4 @@ def test_lm_eval_accuracy_v0_engine(monkeypatch: pytest.MonkeyPatch):
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "0")
-        run_test()
+        run_test("Qwen/Qwen2-1.5B-Instruct")
tests/v1/tpu/test_pallas.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+# SPDX-License-Identifier: Apache-2.0
+from unittest.mock import ANY, patch
+
+import torch
+
+from vllm.attention.backends.abstract import AttentionType
+from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
+                                               NUM_QUERIES_PER_BLOCK,
+                                               PallasAttentionBackendImpl,
+                                               PallasMetadata)
+
+
+def test_ragged_paged_attention():
+    # We verify that the kernel inputs such as sliding_window, etc. are passed
+    # in from the model correctly.
+    # The correctness of the paged attention kernel is tested in the kernel
+    # library.
+    num_heads = 4
+    head_size = 128
+    scale = 1.0
+    num_kv_heads = 4
+    sliding_window = 128
+    logits_soft_cap = 50.0
+    attn_impl = PallasAttentionBackendImpl(
+        num_heads=num_heads,
+        head_size=head_size,
+        scale=scale,
+        num_kv_heads=num_kv_heads,
+        alibi_slopes=None,
+        sliding_window=sliding_window,
+        kv_cache_dtype="auto",
+        logits_soft_cap=logits_soft_cap,
+        attn_type=AttentionType.DECODER,
+    )
+    mock_vmem_limit_bytes = 1024
+    attn_impl.vmem_limit_bytes = mock_vmem_limit_bytes
+
+    class FakeAttentionLayer:
+        _k_scale_float: float
+        _v_scale_float: float
+
+    layer = FakeAttentionLayer()
+    layer._k_scale_float = 1.0
+    layer._v_scale_float = 1.0
+
+    num_tokens = 16
+    num_blocks = 1024
+    block_size = 16
+    query = torch.zeros(num_tokens, num_heads * head_size)
+    key = torch.zeros(num_tokens, num_kv_heads * head_size)
+    value = torch.zeros(num_tokens, num_kv_heads * head_size)
+    kv_cache = torch.zeros(num_blocks, block_size, num_kv_heads * 2, head_size)
+    slot_mapping = torch.zeros(num_tokens, dtype=torch.int64)
+    max_num_reqs = 8
+    max_num_blocks_per_req = 8
+    block_tables = torch.zeros((max_num_reqs, max_num_blocks_per_req),
+                               dtype=torch.int32)
+    context_lens = torch.ones((max_num_reqs, ), dtype=torch.int32)
+    query_lens = [1] * max_num_reqs
+    query_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
+                                                dtype=torch.int32),
+                                   dim=0,
+                                   dtype=torch.int32)
+    num_seqs = torch.tensor([max_num_reqs], dtype=torch.int32)
+    attn_metadata = PallasMetadata(
+        slot_mapping=slot_mapping,
+        block_tables=block_tables,
+        context_lens=context_lens,
+        query_start_loc=query_start_loc,
+        num_seqs=num_seqs,
+    )
+
+    with patch("torch.ops.xla.ragged_paged_attention"
+               ) as mock_ragged_paged_attention:
+        attn_impl.forward(
+            layer=layer,
+            query=query,
+            key=key,
+            value=value,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+
+    mock_ragged_paged_attention.assert_called_once_with(
+        ANY,  # query
+        ANY,  # kv_cache
+        ANY,  # context_lens
+        ANY,  # block_tables
+        ANY,  # query_start_loc
+        ANY,  # num_seqs
+        num_kv_pages_per_block=NUM_KV_PAGES_PER_BLOCK,
+        num_queries_per_block=NUM_QUERIES_PER_BLOCK,
+        vmem_limit_bytes=mock_vmem_limit_bytes,
+        use_kernel=True,
+        sm_scale=scale,
+        sliding_window=sliding_window,
+        soft_cap=logits_soft_cap,
+    )
@@ -92,6 +92,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.sliding_window = sliding_window
+        self.logits_soft_cap = logits_soft_cap
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -99,15 +101,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
             raise NotImplementedError("Alibi slopes is not supported.")
-        if sliding_window is not None:
-            raise NotImplementedError("Sliding window is not supported.")
         if kv_cache_dtype != "auto":
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")
 
         if attn_type != AttentionType.DECODER:
             raise NotImplementedError("Encoder self-attention and "
@@ -172,7 +169,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
             num_queries_per_block=NUM_QUERIES_PER_BLOCK,
             vmem_limit_bytes=self.vmem_limit_bytes,
             use_kernel=True,
-            sm_scale=self.scale)
+            sm_scale=self.scale,
+            sliding_window=self.sliding_window,
+            soft_cap=self.logits_soft_cap,
+        )
 
         return output.reshape(num_tokens, hidden_size)
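Quick numeric intuition for the soft_cap value exercised above (50.0, as in the new test): assuming the conventional cap * tanh(score / cap) formulation, moderate logits pass through nearly unchanged while outliers saturate just below the cap, which is the intended effect of forwarding soft_cap to the kernel call. Illustrative arithmetic only, not part of this diff:

import math

soft_cap = 50.0
for raw in (10.0, 120.0):
    # Scaled tanh keeps small scores roughly linear and bounds large ones.
    capped = soft_cap * math.tanh(raw / soft_cap)
    print(f"{raw} -> {capped:.2f}")  # 10.0 -> 9.87, 120.0 -> 49.18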