vllm/tests/kernels/test_attention_selector.py

# SPDX-License-Identifier: Apache-2.0

from unittest.mock import Mock, patch

import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
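
# NOTE: get_attn_backend is called positionally throughout this file; judging
# from the call sites below, the arguments are
#   get_attn_backend(head_size, dtype, kv_cache_dtype, block_size,
#                    is_attention_free).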


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching."""
    _cached_get_attn_backend.cache_clear()
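

# override_backend_env_variable (tests/kernels/utils) steers the selector by
# setting the VLLM_ATTENTION_BACKEND environment variable via monkeypatch, so
# each case below only has to patch current_platform.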
@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
    """Test that the attention selector can be set via environment variable.
    Note that we do not test FlashAttn because it is the default backend.
    """

    override_backend_env_variable(monkeypatch, name)

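    # current_platform is patched so that each device path can run on any
    # host, without real CPU/ROCm/OpenVINO/CUDA hardware being present.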
    if device == "cpu":
        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
                                       False)
        assert backend.get_name() == "TORCH_SDPA"
    elif device == "hip":
        with patch("vllm.attention.selector.current_platform",
                   RocmPlatform()):
            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
                                       False)
        assert backend.get_name() == "ROCM_FLASH"
    elif device == "openvino":
        with patch("vllm.attention.selector.current_platform",
                   OpenVinoPlatform()), patch.dict('sys.modules',
                                                   {'openvino': Mock()}):
            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
                                       False)
        assert backend.get_name() == "OPENVINO"
    else:
        if name in ["XFORMERS", "FLASHINFER"]:
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                backend = get_attn_backend(16, torch.float16, torch.float16,
                                           16, False)
            assert backend.get_name() == name
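

# FlashAttention can only be selected when its requirements are met (CUDA
# arch, dtype, kv-cache dtype, block size, head size, and the package being
# importable); otherwise the selector is expected to fall back to another
# backend even though FLASH_ATTN was requested via the environment.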
def test_flash_attn(monkeypatch):
    """Test FlashAttn validation."""
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend

    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

    # Unsupported CUDA arch
    with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
        backend = get_attn_backend(16, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported data type
    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported kv cache data type
    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported block size
    backend = get_attn_backend(16, torch.float16, None, 8, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # flash-attn is not installed
    with patch.dict('sys.modules', {'vllm_flash_attn': None}):
        backend = get_attn_backend(16, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported head size
    backend = get_attn_backend(17, torch.float16, None, 16, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Attention-free models should bypass env and use PlaceholderAttention
    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
    assert backend.get_name() != STR_FLASH_ATTN_VAL
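

# An unrecognized VLLM_ATTENTION_BACKEND value should be ignored rather than
# crash the selector: the platform default is chosen instead.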
def test_invalid_env(monkeypatch):
    """Ignore the invalid env variable if it is set."""
    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
        backend = get_attn_backend(32, torch.float16, None, 16, False)
        assert backend.get_name() == "FLASH_ATTN"

        # When head size == 16 (unsupported by FlashAttention), the backend
        # falls back to XFORMERS. (Block size is 16 in both calls; only the
        # head size differs.)
        backend = get_attn_backend(16, torch.float16, None, 16, False)
        assert backend.get_name() == "XFORMERS"