diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu
index 6dae32b2..3f77c76a 100644
--- a/csrc/quantization/fp8/common.cu
+++ b/csrc/quantization/fp8/common.cu
@@ -9,6 +9,18 @@
 #include "../../reduction_utils.cuh"
 
+#ifndef USE_ROCM
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+    std::numeric_limits<FP8_TYPE>::max();
+#else
+  #include "amd/hip_float8.h"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+// Using the default max value from pytorch (240.0) will cause accuracy
+// issue when running dynamic quantization. Here use 224.0f for rocm.
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
 namespace vllm {
 
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
-
 template <bool is_scale_inverted>
-__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    float const val, float const scale) {
+__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+                                                          float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
     x = val * scale;
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+#ifndef USE_ROCM
   return static_cast<c10::Float8_e4m3fn>(r);
+#else
+  // Use hardware cvt instruction for fp8 on rocm
+  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
+                              c10::Float8_e4m3fnuz::from_bits());
+#endif
 }
 
 // Compute the absolute maximum m of the input tensor and store
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale,
-                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
   }
 }
 
@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
 };
 
 typedef struct __align__(4) {
-  c10::Float8_e4m3fn x;
-  c10::Float8_e4m3fn y;
-  c10::Float8_e4m3fn z;
-  c10::Float8_e4m3fn w;
+  FP8_TYPE x;
+  FP8_TYPE y;
+  FP8_TYPE z;
+  FP8_TYPE w;
 } float8x4_t;
 
@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
 }
 
 template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
+__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
 }
 
 template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
+__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
 
 template <typename scalar_t>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
+    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
   float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   int const token_idx = blockIdx.x;
   scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
-  c10::Float8_e4m3fn* __restrict__ token_output =
-      &out[token_idx * hidden_size];
+  FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
 
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
         vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
             scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
       input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
         vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
             <<<grid, block, 0, stream>>>(
-                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
+                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                 input.data_ptr<scalar_t>(),
                 scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                 hidden_size);
diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py
index 9bdace67..8f6a54ff 100644
--- a/tests/kernels/quant_utils.py
+++ b/tests/kernels/quant_utils.py
@@ -2,6 +2,13 @@ from typing import Optional, Tuple, Union
 
 import torch
 
+from vllm.utils import is_hip
+
+# Using the default value (240.0) from pytorch will cause accuracy
+# issue on dynamic quantization models. Here use 224.0 for rocm.
+ROCM_FP8_MAX = 224.0
+FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+
 
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
                                 scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    assert quant_dtype in [torch.int8, FP8_DTYPE]
     if scale_ub is not None:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
 
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
-    qtype_max = as_float32_tensor(qtype_traits.max)
+    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
+    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)
@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
         iscales = as_float32_tensor(s_1 / scales)
         torch_out = as_float32_tensor(x) * iscales
         torch_out = torch_out.round()
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
     else:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
         min_scaling_factor = s_1 / (qtype_max * s_512)
         scales = scales.clamp(min=min_scaling_factor)
         torch_out = as_float32_tensor(x) / scales
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
 
     return torch_out, scales
@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
 
 def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    fp8_traits = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = as_float32_tensor(fp8_traits.max)
+    fp8_traits = torch.finfo(FP8_DTYPE)
+    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
+    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
+    fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)
 
     # For fp8, in order to match the cuda kernel output, we have to do exactly
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_scale = x_max / fp8_max
     ref_iscale = one / ref_scale
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
-        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
+        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
diff --git a/tests/kernels/test_fp8_quant.py b/tests/kernels/test_fp8_quant.py
index 71a92cbc..bae9b392 100644
--- a/tests/kernels/test_fp8_quant.py
+++ b/tests/kernels/test_fp8_quant.py
@@ -2,7 +2,8 @@ import pytest
 import torch
 
 import vllm._custom_ops as ops
-from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
+from tests.kernels.quant_utils import (FP8_DTYPE,
+                                       ref_dynamic_per_tensor_fp8_quant,
                                        ref_dynamic_per_token_quant)
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -31,8 +32,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
     scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
         if scale_ub else None
 
-    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
-                                                      scale_ub)
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
     ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                                scale_ub=scale_ub,
                                                use_per_token_if_dynamic=True)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index b6329859..59fe5329 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -369,9 +369,12 @@ def scaled_fp8_quant(
     # This code assumes batch_dim and num_tokens are flattened
    assert (input.ndim == 2)
     shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
+    out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
+        else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
     if scale is None:
         if use_per_token_if_dynamic:
diff --git a/vllm/config.py b/vllm/config.py
index 19cd4d8b..085daf1b 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -240,7 +240,7 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["gptq", "squeezellm"]
+        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 8b8cf41c..77f12138 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -20,10 +20,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     create_per_tensor_scale_param, cutlass_fp8_supported,
-    per_tensor_dequantize, requantize_with_max_scale)
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+from vllm.utils import is_hip, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -120,6 +121,9 @@ class Fp8LinearMethod(LinearMethodBase):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
+        # Disable marlin for rocm
+        if is_hip():
+            self.use_marlin = False
 
     def create_weights(
@@ -168,6 +172,8 @@ class Fp8LinearMethod(LinearMethodBase):
             scale = create_per_tensor_scale_param(output_partition_sizes,
                                                   **extra_weight_attrs)
             layer.register_parameter("input_scale", scale)
+        else:
+            layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint not serialized fp8, quantize the weights.
@@ -202,9 +208,23 @@ class Fp8LinearMethod(LinearMethodBase):
         # requantize the logical shards as a single weight.
         else:
             # Dequant -> Quant with max scale so we can run per tensor.
+            weight = layer.weight
+            weight_scale = layer.weight_scale
+
+            # If rocm, use float8_e4m3fnuz.
+            if is_hip():
+                weight, weight_scale, input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+
             weight_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight=weight,
+                weight_scale=weight_scale,
                 logical_widths=layer.logical_widths,
             )
@@ -214,8 +234,6 @@ class Fp8LinearMethod(LinearMethodBase):
             if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
-            else:
-                layer.input_scale = None
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
@@ -346,10 +364,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If rocm, use float8_e4m3fnuz as dtype
+            fp8_dtype = torch.float8_e4m3fnuz \
+                if is_hip() else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(layer.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+                                          dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -393,6 +413,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_input_scale.max(), requires_grad=False)
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False)
+            # If rocm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index dbe86902..6cc1c65d 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -6,9 +6,19 @@ from torch.nn import Parameter
 from vllm import _custom_ops as ops
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
+
+# scaled_mm in pytorch on rocm has a bug that requires always
+# providing the scaling factor for the result. This value is created
+# as a global to avoid multiple tensor allocations, and
+# can be removed once pytorch fixes the bug.
+TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
 
 
 def cutlass_fp8_supported() -> bool:
+    # cutlass is not supported on Rocm
+    if is_hip():
+        return False
     capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
@@ -147,13 +157,19 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output, _ = torch._scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
-        return torch.narrow(output, 0, 0, input.shape[0])
+        output = torch._scaled_mm(
+            qinput,
+            weight,
+            out_dtype=input.dtype,
+            scale_a=x_scale,
+            scale_b=weight_scale,
+            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
+            bias=bias)
+        # Since torch 2.5, scaled_mm only returns a single value.
+        # This should be removed when vllm-nvidia also moves to 2.5.
+        if is_hip():
+            return torch.narrow(output, 0, 0, input.shape[0])
+        return torch.narrow(output[0], 0, 0, input.shape[0])
 
     else:
         # Fallback for channelwise case, where we use unfused DQ
@@ -207,3 +223,27 @@ def apply_int8_linear(
                               scale_b=weight_scale,
                               out_dtype=input.dtype,
                               bias=bias)
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bit pattern 10000000 (-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
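
A minimal standalone sketch (not part of the patch) of the invariant that normalize_e4m3fn_to_e4m3fnuz relies on: reinterpreting e4m3fn bits as e4m3fnuz halves the decoded value (exponent bias 8 vs. 7), so doubling the scale keeps the dequantized result unchanged. It assumes a PyTorch build that exposes both fp8 dtypes and can cast them to float32 on CPU.

import torch

# A few exactly representable fp8 values, quantized with an arbitrary scale.
weight = torch.tensor([0.5, -1.0, 1.5, 448.0]).to(torch.float8_e4m3fn)
scale = torch.tensor(0.25)

# Reference dequantization in the original e4m3fn encoding.
dequant_fn = weight.to(torch.float32) * scale

# Reinterpret the raw bits. 0x80 (-128 as int8) is -0.0 in e4m3fn but NaN in
# e4m3fnuz, so it must be cleared before the reinterpretation.
as_int8 = weight.view(torch.int8)
as_int8[as_int8 == -128] = 0
weight_fnuz = as_int8.view(torch.float8_e4m3fnuz)

# The same bit pattern decodes to half the value in e4m3fnuz, so doubling the
# scale preserves weight * scale.
dequant_fnuz = weight_fnuz.to(torch.float32) * (scale * 2.0)
torch.testing.assert_close(dequant_fn, dequant_fnuz)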