diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py
index 133475a3..5c928f27 100644
--- a/tests/quantization/test_compressed_tensors.py
+++ b/tests/quantization/test_compressed_tensors.py
@@ -20,6 +20,23 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     sparse_cutlass_supported)
 from vllm.platforms import current_platform
 
+# AITER only supports per-channel-per-channel INT8 GEMM
+# and per-tensor-per-tensor INT8 GEMM.
+# It does not support mixed-precision MM or mixed quantization schemes.
+ROCM_AITER_SUPPORTED_INT8_MODEL = [
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
+]
+
+# TritonScaledMMLinearKernel only supports symmetric quantization.
+ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
+    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
+    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
+]
+
 
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch):
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
 )
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
     model_path, strategy, quant_type, shape_0, is_symmetric = model_args
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
     with vllm_runner(model_path, enforce_eager=True) as llm:
 
         def check_model(model):
@@ -123,6 +145,8 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
 )
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [10])
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_compressed_tensors_w8a8_logprobs(
     hf_runner,
     vllm_runner,
@@ -130,7 +154,21 @@ def test_compressed_tensors_w8a8_logprobs(
     model_path,
     max_tokens,
     num_logprobs,
+    use_aiter,
+    monkeypatch,
 ):
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not supported by aiter.")
+        # VLLM_ROCM_USE_AITER_LINEAR defaults on, so this enables AITER linear
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     dtype = "bfloat16"
 
     # skip language translation prompt for the static per tensor asym model
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
         name_1="vllm",
     )
 
+    if current_platform.is_rocm():
+        torch.cuda.synchronize()
+
 
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
         ),
     ],
 )
-def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_compressed_tensors_w8a8_dynamic_per_token(
+    vllm_runner,
+    model_args,
+    use_aiter,
+    monkeypatch,
+):
     model_path, strategy = model_args
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not supported by aiter.")
+        # VLLM_ROCM_USE_AITER_LINEAR defaults on, so this enables AITER linear
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     with vllm_runner(model_path, dtype=torch.float16) as llm:
 
         def check_model(model):
@@ -207,6 +267,8 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
         ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
     ],
 )
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platforms.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
     model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
@@ -231,6 +293,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platforms.")
 def test_compressed_tensors_w4a16_marlin24(vllm_runner):
     model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
     with vllm_runner(model_path) as llm:
@@ -271,7 +335,7 @@ def test_compressed_tensors_fp8(vllm_runner):
 
             if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                 assert len(qkv_proj.input_scale.shape) == 0
-                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                 assert qkv_proj.weight_scale.dtype is torch.float32
                 assert len(qkv_proj.weight_scale.shape) == 0
 
@@ -281,6 +345,8 @@ def test_compressed_tensors_fp8(vllm_runner):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platforms.")
 def test_compressed_tensors_kv_cache(vllm_runner):
     model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
     with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(
@@ -356,7 +423,8 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(
diff --git a/vllm/envs.py b/vllm/envs.py
index 53346673..8a03ba32 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -75,6 +75,7 @@ if TYPE_CHECKING:
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -524,6 +525,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),
 
+    # Use AITER linear ops when AITER ops are enabled.
+    # The following related ops are covered:
+    # - scaled_mm (per-tensor / rowwise)
+    "VLLM_ROCM_USE_AITER_LINEAR":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in
+             ("true", "1")),
+
     # Whether to use aiter moe ops.
     # By default is enabled.
     "VLLM_ROCM_USE_AITER_MOE":
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
index a5967995..bedda4c2 100644
--- a/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py
@@ -3,6 +3,8 @@
 import os
 from typing import Dict, List, Optional, Type
 
+from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
+    AiterScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
     CutlassScaledMMLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import (  # noqa: E501
@@ -17,7 +19,7 @@ from vllm.platforms import PlatformEnum, current_platform
 
 _POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = {
     PlatformEnum.CPU: [CutlassScaledMMLinearKernel],
     PlatformEnum.CUDA: [CutlassScaledMMLinearKernel],
-    PlatformEnum.ROCM: [TritonScaledMMLinearKernel],
+    PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
     PlatformEnum.TPU: [XLAScaledMMLinearKernel],
 }
diff --git a/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
new file mode 100644
index 00000000..582b12f7
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py
@@ -0,0 +1,119 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Optional, Tuple
+
+import torch
+
+import vllm.envs as envs
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+
+from .cutlass import CutlassScaledMMLinearKernel
+from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
+
+
+class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 90
+
+    @classmethod
+    def can_implement(
+            cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if not current_platform.is_rocm():
+            return (
+                False,
+                "AiterScaledMMLinearKernel requires `aiter` which is not " +
+                "currently supported on non-ROCm platforms.")
+
+        try:
+            import aiter  # noqa: F401 # deliberately attempt to import aiter
+        except Exception:
+            return (
+                False,
+                "AiterScaledMMLinearKernel requires `aiter` which is not " +
+                "installed on ROCm.")
+        # Check if rocm_aiter_gemm_w8a8_scaled_mm is enabled
+        if not (
+            envs.VLLM_ROCM_USE_AITER_LINEAR
+            and envs.VLLM_ROCM_USE_AITER
+        ):
+            return (False, "AiterScaledMMLinearKernel is disabled. " +
+                    "Enable by setting `VLLM_ROCM_USE_AITER=1` " +
+                    "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " +
+                    "`VLLM_ROCM_USE_AITER_LINEAR` defaults to True.")
+
+        if not c.input_symmetric:
+            return (False,
+                    "AiterScaledMMLinearKernel only supports symmetric " +
+                    "quantization.")
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        super().process_weights_after_loading(layer)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        `AiterScaledMMLinearKernel` implements a fused version of
+        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
+        where scale_a * a and scale_b * b are implemented using numpy-style
+        broadcasting.
+        Currently it only supports per-tensor-per-tensor and
+        per-token-per-channel GEMM through the AITER
+        w8a8 scaled GEMM. `AiterScaledMMLinearKernel` does not support
+        AITER block-scaled GEMM or mixed-precision GEMM.
+        """
+        w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
+
+        # ops.scaled_int8_quant supports both dynamic and static quant:
+        # * dynamic, i_s is None and x_s is computed from x.
+        # * static, i_s is a scalar and x_s is i_s.
+        symmetric = azp_adj is None
+        assert symmetric, ("AiterScaledMMLinearKernel only supports"
+                           " symmetric quantization.")
+        x_q, x_s, x_zp = ops.scaled_int8_quant(x,
+                                               i_s,
+                                               i_zp,
+                                               symmetric=symmetric)
+
+        assert x_zp is None, ("AiterScaledMMLinearKernel only supports"
+                              " symmetric quantization.")
+        out_dtype = x.dtype
+
+        assert (w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0)
+        assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
+        assert bias is None or bias.shape[0] == w_q.shape[
+            1] and bias.dtype == out_dtype
+
+        m = x_q.shape[0]  # a
+        n = w_q.shape[1]  # b
+
+        per_tensor_scale_a = (x_s.numel() == 1)
+        per_tensor_scale_b = (w_s.numel() == 1)
+        per_token_scale_a = (x_s.numel() == m)
+        per_channel_scale_b = (w_s.numel() == n)
+
+        # @TODO:
+        # Maybe broadcast the per-tensor scale into a per-channel scale
+        # if one of the scales is a per-channel scale.
+        # For now, it only supports:
+        # - per-tensor-per-tensor a8w8 scaled GEMM, and
+        # - per-token-per-channel a8w8 scaled GEMM
+        assert ((per_tensor_scale_a and per_tensor_scale_b)
+                or (per_token_scale_a and per_channel_scale_b)), (
+                    "Currently only per-tensor-per-tensor and " +
+                    "per-token-per-channel GEMMs are supported via AITER " +
+                    "w8a8 scaled GEMM. `AiterScaledMMLinearKernel` does " +
+                    "not support AITER block-scaled GEMM.")
+
+        from aiter import gemm_a8w8_CK
+
+        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
+        # a to be [M, K]
+        # b to be [N, K]
+        # CutlassScaledMMLinearKernel prepares weight `w_q` in [K, N] format
+        return gemm_a8w8_CK(x_q, w_q.t(), x_s, w_s, bias).to(out_dtype)
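Reviewer note (not part of the diff): the `apply_weights` docstring only sketches the fused op, so below is a minimal unfused PyTorch reference for what `gemm_a8w8_CK` computes, assuming the shapes documented in aiter.py (`x_q` is [M, K] int8, `w_q` is [K, N] int8 as prepared by the inherited Cutlass path, scales are float32). The helper name is illustrative and not part of vLLM or AITER.

# Unfused reference for the AITER w8a8 scaled GEMM (illustrative only).
from typing import Optional

import torch


def reference_scaled_mm(x_q: torch.Tensor, w_q: torch.Tensor,
                        x_s: torch.Tensor, w_s: torch.Tensor,
                        bias: Optional[torch.Tensor],
                        out_dtype: torch.dtype) -> torch.Tensor:
    m = x_q.shape[0]  # number of tokens
    n = w_q.shape[1]  # number of output channels
    # Per-tensor scales are scalars; per-token / per-channel scales are
    # reshaped so they broadcast along M and N respectively.
    a_s = x_s.reshape(m, 1) if x_s.numel() == m else x_s
    b_s = w_s.reshape(1, n) if w_s.numel() == n else w_s
    out = torch.mm(x_q.to(torch.float32) * a_s,
                   w_q.to(torch.float32) * b_s)
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)

The kernel itself passes `w_q.t()` to `gemm_a8w8_CK` because, per the comment in the diff, the AITER op expects the weight in [N, K] layout while the Cutlass-prepared tensor is [K, N].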
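A related note on kernel selection: the `_POSSIBLE_KERNELS` change lists `AiterScaledMMLinearKernel` ahead of `TritonScaledMMLinearKernel` on ROCm, and `can_implement` is what actually gates the AITER path (missing `aiter`, disabled `VLLM_ROCM_USE_AITER`/`VLLM_ROCM_USE_AITER_LINEAR`, or asymmetric input quantization). A rough first-fit sketch of that selection, assuming the first kernel whose `can_implement` returns True wins; the function below is hypothetical, not vLLM's actual selector.

# Hypothetical first-fit selection over an ordered kernel list.
def pick_scaled_mm_kernel(candidate_kernels, config):
    failure_reasons = []
    for kernel_cls in candidate_kernels:
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("No compatible ScaledMM kernel found: " +
                     "; ".join(failure_reasons))

Under that reading, the Triton kernel remains the ROCm fallback whenever the AITER kernel declines a configuration.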