[Bugfix] Fix broken kernel test due to missing rename for v1 Triton backend (#15282)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
91ca929dc7
commit
8afcd0f633
@ -49,7 +49,7 @@ def test_env(
|
|||||||
RocmPlatform()):
|
RocmPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||||
16, False)
|
16, False)
|
||||||
EXPECTED = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
||||||
assert backend.get_name() == EXPECTED
|
assert backend.get_name() == EXPECTED
|
||||||
elif device == "openvino":
|
elif device == "openvino":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
|
@ -26,7 +26,7 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
|
|||||||
# Test standard ROCm attention
|
# Test standard ROCm attention
|
||||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
|
||||||
assert (backend.get_name() == "ROCM_FLASH"
|
assert (backend.get_name() == "ROCM_FLASH"
|
||||||
or backend.get_name() == "ROCM_ATTN_VLLM_V1")
|
or backend.get_name() == "TRITON_ATTN_VLLM_V1")
|
||||||
|
|
||||||
# mla test for deepseek related
|
# mla test for deepseek related
|
||||||
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user