[FEAT] [ROCm] Add AITER int8 scaled gemm kernel (#15433)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
parent 73aa7041bf
commit 4965ec42d2
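Not part of the commit, but for orientation: a minimal usage sketch of the path this PR adds, assuming a ROCm build of vLLM with the `aiter` package installed. The model name is one of the INT8 checkpoints listed in the tests below; everything else is the standard vLLM `LLM` API.

# Hypothetical usage sketch (not from this commit). Assumes a ROCm build of
# vLLM and an installed `aiter` package.
import os

# VLLM_ROCM_USE_AITER gates all AITER ops; VLLM_ROCM_USE_AITER_LINEAR
# (default "1") additionally gates the int8 scaled-mm path added here.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="neuralmagic/Llama-3.2-1B-quantized.w8a8", dtype="bfloat16")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)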
@@ -20,6 +20,23 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform
 
+# AITER only supports per-channel-per-channel INT8 GEMM
+# and per-tensor-per-tensor INT8 GEMM.
+# It does not support mixed-precision MM or mixed quantization schemes.
+ROCM_AITER_SUPPORTED_INT8_MODEL = [
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
+]
+
+# TritonScaledMMLinearKernel only supports symmetric quantization.
+ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
+    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
+    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
+]
+
+
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch):
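For readers unfamiliar with the two schemes named in the comment above, here is an unfused PyTorch sketch (illustrative only, not vLLM code) of the scale granularities the AITER path accepts: one scalar scale per tensor, versus one scale per activation row (token) and one per weight column (channel).

# Illustrative sketch of the two supported INT8 scale granularities.
import torch

M, K, N = 4, 64, 32
a_q = torch.randint(-128, 127, (M, K), dtype=torch.int8)
b_q = torch.randint(-128, 127, (K, N), dtype=torch.int8)

# Per-tensor / per-tensor: a single scalar scale for each operand.
scale_a = torch.tensor(0.02)
scale_b = torch.tensor(0.01)
out_tensorwise = (scale_a * a_q.float()) @ (scale_b * b_q.float())

# Per-token / per-channel: one scale per activation row, one per weight column.
scale_a_tok = torch.rand(M, 1) * 0.02
scale_b_ch = torch.rand(1, N) * 0.01
out_rowwise = (scale_a_tok * a_q.float()) @ (scale_b_ch * b_q.float())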
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
 )
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
     model_path, strategy, quant_type, shape_0, is_symmetric = model_args
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
     with vllm_runner(model_path, enforce_eager=True) as llm:
 
         def check_model(model):
@@ -123,6 +145,8 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
 )
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [10])
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_compressed_tensors_w8a8_logprobs(
     hf_runner,
     vllm_runner,
@@ -130,7 +154,21 @@ def test_compressed_tensors_w8a8_logprobs(
     model_path,
     max_tokens,
     num_logprobs,
+    use_aiter,
+    monkeypatch,
 ):
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not supported by aiter.")
+        # this will enable VLLM_ROCM_USE_AITER_LINEAR
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     dtype = "bfloat16"
 
     # skip language translation prompt for the static per tensor asym model
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
         name_1="vllm",
     )
+
+    if current_platform.is_rocm():
+        torch.cuda.synchronize()
 
 
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
         ),
     ],
 )
-def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_compressed_tensors_w8a8_dynamic_per_token(
+    vllm_runner,
+    model_args,
+    use_aiter,
+    monkeypatch,
+):
     model_path, strategy = model_args
 
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not supported by aiter.")
+        # this will enable VLLM_ROCM_USE_AITER_LINEAR
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     with vllm_runner(model_path, dtype=torch.float16) as llm:
 
         def check_model(model):
@@ -207,6 +267,8 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
     ],
 )
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="The tests are skipped on non-CUDA platform.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
     model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
@@ -231,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platform.")
 def test_compressed_tensors_w4a16_marlin24(vllm_runner):
     model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
     with vllm_runner(model_path) as llm:
@@ -271,7 +335,7 @@ def test_compressed_tensors_fp8(vllm_runner):
 
         if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
             assert len(qkv_proj.input_scale.shape) == 0
-            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
             assert qkv_proj.weight_scale.dtype is torch.float32
             assert len(qkv_proj.weight_scale.shape) == 0
 
@@ -281,6 +345,8 @@ def test_compressed_tensors_fp8(vllm_runner):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platform.")
 def test_compressed_tensors_kv_cache(vllm_runner):
     model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
     with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(
@@ -356,7 +423,8 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -524,6 +525,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),
 
+    # Use the AITER linear op if AITER ops are enabled.
+    # Related ops:
+    # - scaled_mm (per-tensor / rowwise)
+    "VLLM_ROCM_USE_AITER_LINEAR":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
+             ("true", "1")),
+
     # Whether to use aiter moe ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MOE":
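A sketch of how these two flags compose, based only on the defaults above (the flags appear to be read from the environment when accessed through `vllm.envs`): `VLLM_ROCM_USE_AITER` is the master switch, and `VLLM_ROCM_USE_AITER_LINEAR` can be turned off to keep other AITER ops while the scaled-mm path falls back to the Triton kernel.

# Sketch: enable AITER ops globally but opt out of the AITER linear path.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"         # master switch for AITER ops
os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0"  # skip the int8 scaled-mm path

import vllm.envs as envs

# Both values are parsed by the lambdas registered above.
assert envs.VLLM_ROCM_USE_AITER is True
assert envs.VLLM_ROCM_USE_AITER_LINEAR is False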
@@ -3,6 +3,8 @@
 import os
 from typing import Dict, List, Optional, Type
 
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
+    AiterScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
     CutlassScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
@@ -17,7 +19,7 @@ from vllm.platforms import PlatformEnum, current_platform
 _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
     PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
-    PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
+    PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }
 
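Listing `AiterScaledMMLinearKernel` ahead of `TritonScaledMMLinearKernel` makes it the preferred ROCm kernel when its `can_implement` check passes. The selection helper itself is not shown in this diff; the sketch below is only an assumption of the first-match pattern the ordered list implies.

# Assumed first-match selection over the ordered candidate list
# (_POSSIBLE_KERNELS[platform]); the real vLLM chooser may differ in details.
from typing import List


def choose_scaled_mm_kernel(config, candidates: List[type]) -> type:
    failure_reasons: List[str] = []
    for kernel_cls in candidates:
        can_use, reason = kernel_cls.can_implement(config)
        if can_use:
            return kernel_cls  # first kernel that can handle the config wins
        failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("No compatible ScaledMM kernel:\n" +
                     "\n".join(failure_reasons))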
@@ -0,0 +1,119 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple
+
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+
+from .cutlass import CutlassScaledMMLinearKernel
+from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+
+
+class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 90
+
+    @classmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if not current_platform.is_rocm():
+            return (
+                False,
+                "AiterScaledMMLinearKernel requires `aiter` which is not " +
+                "currently supported on non-ROCm platforms.")
+
+        try:
+            import aiter  # noqa: F401 # deliberately attempt to import aiter
+        except Exception:
+            return (
+                False,
+                "AiterScaledMMLinearKernel requires `aiter` which is not " +
+                "installed on ROCm.")
+
+        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
+        if not (envs.VLLM_ROCM_USE_AITER_LINEAR
+                and envs.VLLM_ROCM_USE_AITER):
+            return (False, "AiterScaledMMLinearKernel is disabled. " +
+                    "Enable by setting `VLLM_ROCM_USE_AITER=1` " +
+                    "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
+                    "`VLLM_ROCM_USE_AITER_LINEAR` defaults to True.")
+
+        if not c.input_symmetric:
+            return (False,
+                    "AiterScaledMMLinearKernel only supports symmetric " +
+                    "quantization.")
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        `AiterScaledMMLinearKernel` implements a fused version of
+        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`,
+        where `scale_a * a` and `scale_b * b` use numpy-style broadcasting.
+        It currently only supports per-tensor-per-tensor and
+        per-token-per-channel GEMM through the AITER w8a8 scaled gemm.
+        It does not support AITER block scaled GEMM or mixed-precision GEMM.
+        """
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        # ops.scaled_int8_quant supports both dynamic and static quant:
+        # * dynamic, i_s is None and x_s computed from x.
+        # * static, i_s is scalar and x_s is i_s.
+        symmetric = azp_adj is None
+        assert symmetric, ("AiterScaledMMLinearKernel only supports"
+                           " symmetric quantization.")
+        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
+                                               i_s,
+                                               i_zp,
+                                               symmetric=symmetric)
+
+        assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
+                              " symmetric quantization.")
+        out_dtype = x.dtype
+
+        assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
+        assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+        assert bias is None or bias.shape[0] == w_q.shape[
+            1] and bias.dtype == out_dtype
+
+        m = x_q.shape[0]  # number of activation rows (tokens)
+        n = w_q.shape[1]  # number of weight output channels
+
+        per_tensor_scale_a = (x_s.numel() == 1)
+        per_tensor_scale_b = (w_s.numel() == 1)
+        per_token_scale_a = (x_s.numel() == m)
+        per_channel_scale_b = (w_s.numel() == n)
+
+        # @TODO:
+        # Maybe broadcast the per-tensor scale into a per-channel scale
+        # if one of the scales is a per-channel scale.
+        # For now, it only supports:
+        # - per-tensor-per-tensor a8w8 scaled GEMM, and
+        # - per-token-per-channel a8w8 scaled GEMM
+        assert ((per_tensor_scale_a and per_tensor_scale_b)
+                or (per_token_scale_a and per_channel_scale_b)), (
+                    "Currently only per-tensor-per-tensor and "
+                    "per-token-per-channel GEMM are supported through the "
+                    "AITER w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
+                    "does not support AITER block scaled GEMM.")
+
+        from aiter import gemm_a8w8_CK
+
+        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
+        # a to be [M, K]
+        # b to be [N, K]
+        # CutlassScaledMMLinearKernel prepares weight `w_q` in [K, N] format
+        return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
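The docstring above pins down the semantics: the fused call should match `torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`. A hedged sanity check one could run on a ROCm machine with `aiter` installed (hypothetical test code, not part of the commit; it relies only on the argument layout documented in the comments above):

# Hypothetical check of gemm_a8w8_CK against the unfused reference.
# Shapes follow the kernel's asserts: K and N multiples of 16, bf16 output.
import torch
from aiter import gemm_a8w8_CK

M, K, N = 32, 128, 256
a_q = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
b_q = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda")  # [N, K]
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)  # per-token
scale_b = torch.rand(N, 1, device="cuda", dtype=torch.float32)  # per-channel

out = gemm_a8w8_CK(a_q, b_q, scale_a, scale_b, None).to(torch.bfloat16)
ref = torch.mm(scale_a * a_q.float(),
               (scale_b * b_q.float()).t()).to(torch.bfloat16)
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-1)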