
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files

    This commit adds SPDX license headers to Python source files, as
    recommended to the project by the Linux Foundation. These headers give
    a concise, human- and machine-readable statement of the license of each
    source file. They remove any ambiguity about the license of the code
    and can be consumed by tooling that helps manage license compliance.

    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool
    do its job.

    More information can be found on the SPDX site:

    - https://spdx.dev/learn/handling-license-info/

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit

    Signed-off-by: Russell Bryant <rbryant@redhat.com>
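The second commit wires the header check into pre-commit. As a rough illustration of what such a check can look like, here is a minimal, hypothetical checker script that a `repo: local` pre-commit hook could run against staged Python files. The script name, the exact header text, and the hook wiring are assumptions for illustration, not necessarily what vLLM actually ships:

```python
#!/usr/bin/env python3
# Hypothetical SPDX header checker (illustrative sketch, not vLLM's actual tool).
# A local pre-commit hook would pass the staged *.py files as arguments.
import sys

# Header each Python source file is expected to carry (assumed value).
SPDX_HEADER = "# SPDX-License-Identifier: Apache-2.0"


def has_spdx_header(path: str) -> bool:
    """Return True if the file starts with the SPDX header.

    A shebang line ("#!...") is allowed to precede the header.
    """
    with open(path, encoding="utf-8") as f:
        first = f.readline().strip()
        if first.startswith("#!"):
            first = f.readline().strip()
    return first == SPDX_HEADER


def main(paths: list[str]) -> int:
    missing = [p for p in paths if not has_spdx_header(p)]
    for path in missing:
        print(f"{path}: missing '{SPDX_HEADER}'", file=sys.stderr)
    # A non-zero exit status is what makes the pre-commit hook fail.
    return 1 if missing else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
```

A `repo: local` entry in `.pre-commit-config.yaml` pointing at a script like this (restricted with `files: \.py$`) is one straightforward way to enforce the header on every commit.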
108 lines · 4.3 KiB · Python
# SPDX-License-Identifier: Apache-2.0

from unittest.mock import Mock, patch

import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL


@pytest.fixture(autouse=True)
def clear_cache():
    """Clear lru cache to ensure each test case runs without caching."""
    _cached_get_attn_backend.cache_clear()


@pytest.mark.parametrize(
    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
def test_env(name: str, device: str, monkeypatch):
    """Test that the attention selector can be set via environment variable.
    Note that we do not test FlashAttn because it is the default backend.
    """

    override_backend_env_variable(monkeypatch, name)

    if device == "cpu":
        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
                                       False)
        assert backend.get_name() == "TORCH_SDPA"
    elif device == "hip":
        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
                                       False)
        assert backend.get_name() == "ROCM_FLASH"
    elif device == "openvino":
        with patch("vllm.attention.selector.current_platform",
                   OpenVinoPlatform()), patch.dict('sys.modules',
                                                   {'openvino': Mock()}):
            backend = get_attn_backend(16, torch.float16, torch.float16, 16,
                                       False)
        assert backend.get_name() == "OPENVINO"
    else:
        if name in ["XFORMERS", "FLASHINFER"]:
            with patch("vllm.attention.selector.current_platform",
                       CudaPlatform()):
                backend = get_attn_backend(16, torch.float16, torch.float16,
                                           16, False)
            assert backend.get_name() == name


def test_flash_attn(monkeypatch):
    """Test FlashAttn validation."""
    # TODO: When testing for v1, pipe in `use_v1` as an argument to
    # get_attn_backend

    override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)

    # Unsupported CUDA arch
    with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
        backend = get_attn_backend(16, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported data type
    backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported kv cache data type
    backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported block size
    backend = get_attn_backend(16, torch.float16, None, 8, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # flash-attn is not installed
    with patch.dict('sys.modules', {'vllm_flash_attn': None}):
        backend = get_attn_backend(16, torch.float16, None, 16, False)
        assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Unsupported head size
    backend = get_attn_backend(17, torch.float16, None, 16, False)
    assert backend.get_name() != STR_FLASH_ATTN_VAL

    # Attention-free models should bypass env and use PlaceholderAttention
    backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
    assert backend.get_name() != STR_FLASH_ATTN_VAL


def test_invalid_env(monkeypatch):
    """Ignore the invalid env variable if it is set."""
    override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
    with patch("vllm.attention.selector.current_platform", CudaPlatform()):
        backend = get_attn_backend(32, torch.float16, None, 16, False)
        assert backend.get_name() == "FLASH_ATTN"

        # head size 16 is not supported by FLASH_ATTN, so the selector
        # falls back to XFORMERS
        backend = get_attn_backend(16, torch.float16, None, 16, False)
        assert backend.get_name() == "XFORMERS"