[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (#7210)

Author: Charlie Fu · 2024-08-16 12:06:30 -05:00 · committed by GitHub
parent ec724a725e
commit e837b624f2
7 changed files with 164 additions and 49 deletions

View File

@@ -9,6 +9,18 @@
 #include "../../reduction_utils.cuh"

+#ifndef USE_ROCM
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+    std::numeric_limits<FP8_TYPE>::max();
+#else
+  #include "amd/hip_float8.h"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+// Using the default max value from pytorch (240.0) will cause accuracy
+// issue when running dynamic quantization. Here use 224.0f for rocm.
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
 namespace vllm {

 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
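
For context (not part of the diff): torch reports a maximum of 448 for float8_e4m3fn and 240 for float8_e4m3fnuz, and the kernel deliberately clamps to a slightly smaller 224.0 on ROCm to avoid the accuracy issue mentioned in the comment above. A minimal sketch of those ranges:

import torch

# torch-reported ranges of the two fp8 formats
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 (CUDA path)
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 (ROCm path)

# The kernel clamps to a conservative 224.0 on ROCm before conversion,
# mirroring FP8_E4M3_MAX above.
ROCM_FP8_MAX = 224.0
x = torch.tensor([300.0, -500.0, 3.5])
print(x.clamp(-ROCM_FP8_MAX, ROCM_FP8_MAX))    # tensor([ 224., -224., 3.5])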
@@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }

-#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
-
 template <bool is_scale_inverted>
-__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    float const val, float const scale) {
+__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+                                                          float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
     x = val * scale;
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
   }

   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+#ifndef USE_ROCM
   return static_cast<c10::Float8_e4m3fn>(r);
+#else
+  // Use hardware cvt instruction for fp8 on rocm
+  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
+                              c10::Float8_e4m3fnuz::from_bits());
+#endif
 }

 // Compute the absolute maximum m of the input tensor and store
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale,
-                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
   }
 }

@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
 };

 typedef struct __align__(4) {
-  c10::Float8_e4m3fn x;
-  c10::Float8_e4m3fn y;
-  c10::Float8_e4m3fn z;
-  c10::Float8_e4m3fn w;
+  FP8_TYPE x;
+  FP8_TYPE y;
+  FP8_TYPE z;
+  FP8_TYPE w;
 }
 float8x4_t;

@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
 }

 template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
+__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
 }

 template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
+__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,

 template <typename scalar_t>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
+    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
   float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   int const token_idx = blockIdx.x;

   scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
-  c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
+  FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];

   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
         vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
             scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
       input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
         vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
             <<<grid, block, 0, stream>>>(
-                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
+                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                 input.data_ptr<scalar_t>(),
                 scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                 hidden_size);

View File

@@ -2,6 +2,13 @@ from typing import Optional, Tuple, Union

 import torch

+from vllm.utils import is_hip
+
+# Using the default value (240.0) from pytorch will cause accuracy
+# issue on dynamic quantization models. Here use 224.0 for rocm.
+ROCM_FP8_MAX = 224.0
+FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+

 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')

@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
                                 scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:

-    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    assert quant_dtype in [torch.int8, FP8_DTYPE]
     if scale_ub is not None:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE

     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
-    qtype_max = as_float32_tensor(qtype_traits.max)
+    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
+    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)

@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
         iscales = as_float32_tensor(s_1 / scales)
         torch_out = as_float32_tensor(x) * iscales
         torch_out = torch_out.round()
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
     else:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
         min_scaling_factor = s_1 / (qtype_max * s_512)
         scales = scales.clamp(min=min_scaling_factor)
         torch_out = as_float32_tensor(x) / scales
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)

     return torch_out, scales

@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
 def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
         -> Tuple[torch.tensor, torch.tensor]:

-    fp8_traits = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = as_float32_tensor(fp8_traits.max)
+    fp8_traits = torch.finfo(FP8_DTYPE)
+    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
+    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
+    fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)

     # For fp8, in order to match the cuda kernel output, we have to do exactly
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_scale = x_max / fp8_max
     ref_iscale = one / ref_scale
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
-        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
+        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
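
A hedged usage sketch of the reference helpers above (only FP8_DTYPE and ref_dynamic_per_tensor_fp8_quant are taken from this file; shapes and dtypes are illustrative): the per-tensor reference now clamps to the platform-specific fp8 range and returns an output in the platform's fp8 dtype together with a single scale.

import torch

from tests.kernels.quant_utils import (FP8_DTYPE,
                                       ref_dynamic_per_tensor_fp8_quant)

x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)

assert ref_out.dtype == FP8_DTYPE      # float8_e4m3fnuz on ROCm
assert ref_scale.shape == (1, )        # single per-tensor scale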

View File

@@ -2,7 +2,8 @@ import pytest
 import torch

 import vllm._custom_ops as ops
-from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
+from tests.kernels.quant_utils import (FP8_DTYPE,
+                                       ref_dynamic_per_tensor_fp8_quant,
                                        ref_dynamic_per_token_quant)

 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -31,8 +32,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
     scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
         if scale_ub else None

-    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
-                                                      scale_ub)
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
     ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                                scale_ub=scale_ub,
                                                use_per_token_if_dynamic=True)

View File

@@ -369,9 +369,12 @@ def scaled_fp8_quant(
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
     shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
+    out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
+        else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)

     if scale is None:
         if use_per_token_if_dynamic:
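
An illustrative call (not part of the diff) showing the effect of this change: the quantized output of ops.scaled_fp8_quant comes back as torch.float8_e4m3fnuz on ROCm and torch.float8_e4m3fn elsewhere.

import torch

from vllm import _custom_ops as ops
from vllm.utils import is_hip

x = torch.randn(16, 4096, dtype=torch.float16, device='cuda')
q, scale = ops.scaled_fp8_quant(x)  # dynamic per-tensor scale

expected = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
assert q.dtype == expected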

View File

@@ -240,7 +240,7 @@ class ModelConfig:
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["gptq", "squeezellm"]
+        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors"

View File

@@ -20,10 +20,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     create_per_tensor_scale_param, cutlass_fp8_supported,
-    per_tensor_dequantize, requantize_with_max_scale)
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+from vllm.utils import is_hip, print_warning_once

 ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -120,6 +121,9 @@ class Fp8LinearMethod(LinearMethodBase):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
+        # Disable marlin for rocm
+        if is_hip():
+            self.use_marlin = False

     def create_weights(
         self,
@@ -168,6 +172,8 @@ class Fp8LinearMethod(LinearMethodBase):
             scale = create_per_tensor_scale_param(output_partition_sizes,
                                                   **extra_weight_attrs)
             layer.register_parameter("input_scale", scale)
+        else:
+            layer.register_parameter("input_scale", None)

     def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint not serialized fp8, quantize the weights.
@@ -202,9 +208,23 @@ class Fp8LinearMethod(LinearMethodBase):
             # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If rocm, use float8_e4m3fnuz.
+                if is_hip():
+                    weight, weight_scale, input_scale = \
+                        normalize_e4m3fn_to_e4m3fnuz(
+                            weight=weight,
+                            weight_scale=weight_scale,
+                            input_scale=layer.input_scale)
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale,
+                                                      requires_grad=False)
+
                 weight_scale, weight = requantize_with_max_scale(
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
+                    weight=weight,
+                    weight_scale=weight_scale,
                     logical_widths=layer.logical_widths,
                 )

@@ -214,8 +234,6 @@ class Fp8LinearMethod(LinearMethodBase):
             if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
-            else:
-                layer.input_scale = None

         if self.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
@@ -346,10 +364,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):

         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If rocm, use float8_e4m3fnuz as dtype
+            fp8_dtype = torch.float8_e4m3fnuz \
+                if is_hip() else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(layer.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+                                          dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -393,6 +413,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_input_scale.max(), requires_grad=False)
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False)
+            # If rocm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)

             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.

View File

@@ -6,9 +6,19 @@ from torch.nn import Parameter
 from vllm import _custom_ops as ops
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
+
+# scaled_mm in pytorch on rocm has a bug that requires always
+# providing scaling factor for result. This value is created
+# as global value to avoid multiple tensor allocations, and
+# can be removed once pytorch fixes the bug.
+TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None


 def cutlass_fp8_supported() -> bool:
+    # cutlass is not supported on Rocm
+    if is_hip():
+        return False
+
     capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
@@ -147,13 +157,19 @@ def apply_fp8_linear(

     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output, _ = torch._scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
-        return torch.narrow(output, 0, 0, input.shape[0])
+        output = torch._scaled_mm(
+            qinput,
+            weight,
+            out_dtype=input.dtype,
+            scale_a=x_scale,
+            scale_b=weight_scale,
+            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
+            bias=bias)
+        # Since in torch 2.5, scaled_mm only returns single value
+        # This should be removed when vllm-nvidia also moves to 2.5
+        if is_hip():
+            return torch.narrow(output, 0, 0, input.shape[0])
+        return torch.narrow(output[0], 0, 0, input.shape[0])

     else:
         # Fallback for channelwise case, where we use unfused DQ
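
The return-type handling above exists because the torch build used on ROCm returns a single tensor from torch._scaled_mm, while the CUDA build in use at the time still returns an (output, amax) tuple. A minimal compatibility sketch under that assumption (the helper name is hypothetical, not part of the diff):

import torch

def scaled_mm_compat(qinput, weight, *, scale_a, scale_b, out_dtype,
                     bias=None, scale_result=None):
    # torch._scaled_mm returns (output, amax) on older CUDA builds and a
    # single tensor on the ROCm / newer builds; normalize to a tensor here.
    result = torch._scaled_mm(qinput, weight, out_dtype=out_dtype,
                              scale_a=scale_a, scale_b=scale_b,
                              scale_result=scale_result, bias=bias)
    return result if isinstance(result, torch.Tensor) else result[0]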
@@ -207,3 +223,27 @@ def apply_int8_linear(
                           scale_b=weight_scale,
                           out_dtype=input.dtype,
                           bias=bias)
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
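
A small illustration of why the scale doubles (assuming a PyTorch build with float8 support; values are for demonstration only, not part of the diff): the two formats differ by one in exponent bias, so reinterpreting e4m3fn bits as e4m3fnuz halves each decoded value, and doubling the scale keeps the dequantized result unchanged.

import torch

w_fn = torch.tensor([1.0, 2.0, -0.5]).to(torch.float8_e4m3fn)
scale = torch.tensor(0.1)

# Same bits, reinterpreted as e4m3fnuz: each value is halved.
w_fnuz = w_fn.view(torch.int8).view(torch.float8_e4m3fnuz)
print(w_fnuz.float())                  # tensor([ 0.5000,  1.0000, -0.2500])

# Doubling the scale restores the original dequantized values.
print(w_fn.float() * scale)            # reference dequantization
print(w_fnuz.float() * (scale * 2.0))  # matches the reference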