From e85829450d8016309d71de9f347e2147ee03400a Mon Sep 17 00:00:00 2001
From: Charlie Fu
Date: Mon, 31 Mar 2025 06:42:18 -0500
Subject: [PATCH] [Feature][ROCm]Enable fusion pass for torch.compile on ROCm
 (#15050)

Signed-off-by: charlifu
---
 csrc/quantization/fp8/common.cu               |  7 +--
 csrc/quantization/fp8/common.cuh              | 41 ++-----------
 ...fused_layernorm_dynamic_per_token_quant.cu | 28 ++++-----
 .../fused_kernels/layernorm_utils.cuh         | 13 ++--
 .../fused_kernels/quant_conversions.cuh       |  4 +-
 csrc/quantization/utils.cuh                   | 59 +++++++++++++++++++
 tests/compile/test_fusion.py                  |  8 ++-
 vllm/compilation/fusion.py                    |  4 +-
 8 files changed, 92 insertions(+), 72 deletions(-)
 create mode 100644 csrc/quantization/utils.cuh

diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index 8f9aa21a..eceb3a8e 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -30,9 +30,6 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor =
-      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
-
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
 
@@ -67,8 +64,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
-                      min_scaling_factor);
+    token_scale = max(token_scale / quant_type_max_v<fp8_type>,
+                      min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh
index d331c63a..def8b31b 100644
--- a/csrc/quantization/fp8/common.cuh
+++ b/csrc/quantization/fp8/common.cuh
@@ -1,20 +1,12 @@
 #pragma once
 
 #include "quantization/vectorization.cuh"
+#include "quantization/utils.cuh"
 
 #include
-#include
-
-#ifndef USE_ROCM
-  #include
-  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
-#else
-  #include
-  #include
-  #include
+#ifdef USE_ROCM
   #include "amd/quant_utils.cuh"
-  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
-  #define MAYBE_HOST_DEVICE
 #endif
 
 // Determines the preferred FP8 type for the current platform.
@@ -31,29 +23,6 @@ static bool is_fp8_ocp() {
 #endif
 }
 
-template <typename T>
-struct fp8_e4m3_adjusted_max;
-
-template <>
-struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
-  static constexpr c10::Float8_e4m3fn val() {
-    return std::numeric_limits<c10::Float8_e4m3fn>::max();
-  }
-};
-
-// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
-// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
-template <>
-struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
-  static constexpr c10::Float8_e4m3fnuz val() {
-    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
-  }
-};
-
-template <typename T>
-MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
-    fp8_e4m3_adjusted_max<T>::val();
-
 namespace vllm {
 
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -76,8 +45,8 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
     x = val / scale;
   }
 
-  float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
-                 fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  float r =
+      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
 #ifndef USE_ROCM
   return static_cast<fp8_type>(r);
 #else
@@ -123,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
   }
 }
diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
index 1be89c50..2b6ab7fc 100644
--- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
+++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
@@ -14,8 +14,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
     float* __restrict__ scales,          // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
-    float const* scale_ub, float const var_epsilon,
-    float const min_scaling_factor, int32_t const hidden_size,
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
     scalar_t* __restrict__ residual = nullptr) {
   float rms = 0.0f;
   float token_scale = 0.0f;
@@ -27,8 +26,8 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
 
   // Compute scale
   vllm::vectorized::compute_dynamic_per_token_scales(
-      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
-      hidden_size, residual);
+      &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
+      residual);
 
   // RMS Norm + Quant
   if constexpr (std::is_same_v) {
@@ -50,8 +49,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
     float* __restrict__ scales,          // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
-    float const* scale_ub, float const var_epsilon,
-    float const min_scaling_factor, int32_t const hidden_size,
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
     scalar_t* __restrict__ residual = nullptr) {
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -60,8 +58,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
 
   if (can_vectorize) {
     return rms_norm_dynamic_per_token_quant_vec(
-        out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
-        hidden_size, residual);
+        out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
+        residual);
   }
 
   float rms = 0.0f;
@@ -72,8 +70,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
                                  var_epsilon, residual);
   // Compute Scale
   vllm::compute_dynamic_per_token_scales(
-      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
-      hidden_size, residual);
+      &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
+      residual);
 
   // RMS Norm + Quant
   if constexpr (std::is_same_v) {
@@ -105,11 +103,6 @@ void rms_norm_dynamic_per_token_quant_dispatch(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  const float min_scaling_factor =
-      out.dtype() == torch::kInt8
-          ? std::numeric_limits::epsilon()
-          : 1.0f / (std::numeric_limits::max() * 512.f);
-
   if (residual.has_value()) {
     VLLM_DISPATCH_QUANT_TYPES(
         out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
@@ -119,8 +112,7 @@
               out.data_ptr(), scales.data_ptr(), input.data_ptr(),
               weight.data_ptr(),
               scale_ub.has_value() ? scale_ub->data_ptr() : nullptr,
-              var_epsilon, min_scaling_factor, hidden_size,
-              residual->data_ptr());
+              var_epsilon, hidden_size, residual->data_ptr());
         });
 
   } else {
@@ -132,7 +124,7 @@
               out.data_ptr(), scales.data_ptr(), input.data_ptr(),
               weight.data_ptr(),
               scale_ub.has_value() ? scale_ub->data_ptr() : nullptr,
-              var_epsilon, min_scaling_factor, hidden_size, nullptr);
+              var_epsilon, hidden_size, nullptr);
         });
   }
 }
diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh
index b5cea98f..e6d23cd2 100644
--- a/csrc/quantization/fused_kernels/layernorm_utils.cuh
+++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh
@@ -5,6 +5,7 @@
  */
 
 #include "quantization/vectorization.cuh"
+#include "quantization/utils.cuh"
 #include "quant_conversions.cuh"
 
 #ifndef USE_ROCM
@@ -51,11 +52,11 @@ __device__ void compute_dynamic_per_token_scales(
     float* __restrict__ token_scale, float* __restrict__ all_token_scales,
     scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
     float const rms, float const* __restrict__ scale_ub,
-    float const min_scaling_factor, int32_t const hidden_size,
+    int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
   int64_t const token_offset = blockIdx.x * static_cast(hidden_size);
   ;
-  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
+  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
   float block_absmax_val_maybe = 0.0f;
   for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
@@ -83,7 +84,7 @@ __device__ void compute_dynamic_per_token_scales(
     scale = block_absmax_val_maybe;
   }
   // token scale computation
-  scale = max(scale / qmax, min_scaling_factor);
+  scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
   s_token_scale = scale;                 // Shared memory store
   all_token_scales[blockIdx.x] = scale;  // Global output store
 }
@@ -184,7 +185,7 @@ __device__ void compute_dynamic_per_token_scales(
     float* __restrict__ token_scale, float* __restrict__ all_token_scales,
     scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
     float const rms, float const* __restrict__ scale_ub,
-    float const min_scaling_factor, int32_t const hidden_size,
+    int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
   int64_t const token_offset = blockIdx.x * static_cast(hidden_size);
   ;
@@ -200,7 +201,7 @@ __device__ void compute_dynamic_per_token_scales(
         reinterpret_cast const*>(&residual[token_offset]);
   }
 
-  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
+  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
   int32_t const num_vec_elems = hidden_size >> 2;
   float block_absmax_val_maybe = 0.0f;
@@ -248,7 +249,7 @@ __device__ void compute_dynamic_per_token_scales(
     scale = block_absmax_val_maybe;
   }
   // token scale computation
-  scale = max(scale / qmax, min_scaling_factor);
+  scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
   s_token_scale = scale;                 // shared memory store
   all_token_scales[blockIdx.x] = scale;  // global output store
 }
diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh
index 9ac7b188..7c10aaa8 100644
--- a/csrc/quantization/fused_kernels/quant_conversions.cuh
+++ b/csrc/quantization/fused_kernels/quant_conversions.cuh
@@ -33,8 +33,8 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
 
 template <typename fp8_type>
 static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
-  float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
-                       fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  float const r =
+      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
   return static_cast<fp8_type>(r);
 }
diff --git a/csrc/quantization/utils.cuh b/csrc/quantization/utils.cuh
new file mode 100644
index 00000000..73055a15
--- /dev/null
+++ b/csrc/quantization/utils.cuh
@@ -0,0 +1,59 @@
+#pragma once
+
+/**
+ * Quantization utilities including:
+ *   Adjusted maximum values for qtypes.
+ *   Minimum scaling factors for qtypes.
+ */
+
+#include
+#include
+
+#ifndef USE_ROCM
+  #include
+  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
+#else
+  #include
+  #include
+  #include
+  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
+  #define MAYBE_HOST_DEVICE
+#endif
+
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
+                                      std::is_same_v<T, int8_t>>>
+struct quant_type_max {
+  static constexpr T val() { return std::numeric_limits<T>::max(); }
+};
+
+// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
+// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
+template <>
+struct quant_type_max<c10::Float8_e4m3fnuz> {
+  static constexpr c10::Float8_e4m3fnuz val() {
+    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
+  }
+};
+
+template <typename T>
+MAYBE_HOST_DEVICE static constexpr T quant_type_max_v =
+    quant_type_max<T>::val();
+
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
+                                      std::is_same_v<T, int8_t>>>
+struct min_scaling_factor {
+  C10_DEVICE C10_ALWAYS_INLINE static float val() {
+    return 1.0f / (quant_type_max_v<T> * 512.0f);
+  }
+};
+
+template <>
+struct min_scaling_factor<int8_t> {
+  C10_DEVICE C10_ALWAYS_INLINE static float val() {
+    return std::numeric_limits<float>::epsilon();
+  }
+};
\ No newline at end of file
diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py
index aaf02778..a1adf708 100644
--- a/tests/compile/test_fusion.py
+++ b/tests/compile/test_fusion.py
@@ -2,7 +2,6 @@
 
 import pytest
 import torch
-from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
 import vllm.plugins
@@ -14,9 +13,12 @@ from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
+from vllm.platforms import current_platform
 
 from .backend import TestBackend
 
+FP8_DTYPE = current_platform.fp8_dtype()
+
 
 class TestModel(torch.nn.Module):
 
@@ -59,8 +61,8 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("static", [True, False])
 @pytest.mark.parametrize("cutlass_fp8_enabled",
                          [True, False] if CUTLASS_FP8_SUPPORTED else [False])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
-                    reason="Only test on CUDA")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
+                    reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                               cutlass_fp8_enabled):
     torch.set_default_device("cuda")
diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py
index 0c3d8697..b46f5f52 100644
--- a/vllm/compilation/fusion.py
+++ b/vllm/compilation/fusion.py
@@ -4,8 +4,6 @@ from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
 
 import torch
 import torch._inductor.pattern_matcher as pm
-# TODO(luka) use vllm.utils once #10836 landed
-from compressed_tensors.quantization import FP8_DTYPE
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
@@ -13,12 +11,14 @@ from torch._ops import OpOverload
 
 from vllm.config import CompilationConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 from .fx_utils import find_getitem_maybe
 from .multi_output_match import MultiOutputMatch
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
+FP8_DTYPE = current_platform.fp8_dtype()
 
 
 def empty_bf16(*args, **kwargs):
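
Note (reviewer sketch, not part of the patch): the new csrc/quantization/utils.cuh centralizes the per-qtype maximum and the minimum scaling factor that the kernels above previously hard-coded. The snippet below is only an illustrative sketch of where those two helpers slot in when a per-token absmax is turned into a scale; the helper name token_scale_from_absmax and its signature are made up for illustration, and only quant_type_max_v and min_scaling_factor come from the patch.

// Illustrative only -- mirrors how the patched kernels derive a per-token
// scale from a block absmax using the shared helpers in quantization/utils.cuh.
#include "quantization/utils.cuh"

template <typename fp8_type>
__device__ float token_scale_from_absmax(float absmax_val,
                                         float const* scale_ub) {
  // Optionally clamp the block absmax to the user-provided upper bound,
  // as the fused kernels do when a scale_ub tensor is passed in.
  float scale = (scale_ub != nullptr) ? fminf(absmax_val, *scale_ub)
                                      : absmax_val;
  // quant_type_max_v<fp8_type> resolves to std::numeric_limits<T>::max() for
  // OCP fp8 and int8, and to 224.0 (0x7E) for ROCm's Float8_e4m3fnuz;
  // min_scaling_factor<fp8_type>::val() replaces the old hand-written
  // 1.0f / (max * 512.f) lower bound on the scale.
  return fmaxf(scale / quant_type_max_v<fp8_type>,
               min_scaling_factor<fp8_type>::val());
}

The scale_ub handling above is simplified; it is only meant to show how the fused-quantization kernels and the standalone fp8 kernels now share one definition of the quantization maximum and the scale floor.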