[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (#7210)

parent ec724a725e
commit e837b624f2
@@ -9,6 +9,18 @@
 
 #include "../../reduction_utils.cuh"
 
+#ifndef USE_ROCM
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+    std::numeric_limits<FP8_TYPE>::max();
+#else
+#include "amd/hip_float8.h"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+// Using the default max value from pytorch (240.0) will cause accuracy
+// issue when running dynamic quantization. Here use 224.0f for rocm.
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
 namespace vllm {
 
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
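
Note: the 224.0f cap above replaces the std::numeric_limits query only on ROCm. As a quick illustrative check (not part of the commit), the two fp8 formats report different ranges in stock PyTorch, and the ROCm kernel deliberately clamps below the e4m3fnuz limit:

import torch

# Illustrative only: compare the representable ranges of the two fp8 formats.
fn_max = torch.finfo(torch.float8_e4m3fn).max      # 448.0 (OCP e4m3)
fnuz_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240.0 (fnuz e4m3)

print(fn_max, fnuz_max)
# The ROCm kernel clamps to 224.0 rather than 240.0 to avoid the dynamic
# quantization accuracy issue mentioned in the comment above.
ROCM_FP8_MAX = 224.0
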
@@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
-
 template <bool is_scale_inverted>
-__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    float const val, float const scale) {
+__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+                                                          float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
     x = val * scale;
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
   }
 
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+#ifndef USE_ROCM
   return static_cast<c10::Float8_e4m3fn>(r);
+#else
+  // Use hardware cvt instruction for fp8 on rocm
+  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
+                              c10::Float8_e4m3fnuz::from_bits());
+#endif
 }
 
 // Compute the absolute maximum m of the input tensor and store
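
For readers who do not want to trace the CUDA, the conversion above amounts to scale, saturate, then cast. A rough Python restatement follows (illustrative sketch only; FP8_DTYPE, FP8_MAX and the function name are local assumptions of this sketch, not symbols exported by the kernel, and the non-inverted branch, which divides by the scale, is not shown in the hunk above):

import torch

FP8_DTYPE = torch.float8_e4m3fn   # torch.float8_e4m3fnuz on ROCm
FP8_MAX = torch.finfo(FP8_DTYPE).max

def scaled_fp8_conversion(val: torch.Tensor, scale: float,
                          is_scale_inverted: bool) -> torch.Tensor:
    # Multiply by an inverted scale or divide by a regular one,
    # then saturate to the fp8 range before casting.
    x = val * scale if is_scale_inverted else val / scale
    return x.clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
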
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale,
-                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
   }
 }
 
@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
 };
 
 typedef struct __align__(4) {
-  c10::Float8_e4m3fn x;
-  c10::Float8_e4m3fn y;
-  c10::Float8_e4m3fn z;
-  c10::Float8_e4m3fn w;
+  FP8_TYPE x;
+  FP8_TYPE y;
+  FP8_TYPE z;
+  FP8_TYPE w;
 }
 float8x4_t;
 
@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
 }
 
 template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
+__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
 }
 
 template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
+__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
 
 template <typename scalar_t>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
+    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
   float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   int const token_idx = blockIdx.x;
 
   scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
-  c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
+  FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
 
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
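
The min_scaling_factor above puts a floor under the per-token scale so that all-zero or tiny rows do not produce a zero divisor. A short sketch of the same scale computation (illustrative; FP8_MAX and per_token_scales are assumptions of this sketch):

import torch

FP8_MAX = 448.0  # 224.0 on ROCm, matching FP8_E4M3_MAX above

def per_token_scales(x: torch.Tensor) -> torch.Tensor:
    # One scale per token (row): absmax / FP8_MAX, floored at
    # 1 / (FP8_MAX * 512), mirroring the kernel's min_scaling_factor.
    absmax = x.abs().amax(dim=-1, keepdim=True).float()
    min_scaling_factor = 1.0 / (FP8_MAX * 512.0)
    return (absmax / FP8_MAX).clamp(min=min_scaling_factor)
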
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
         vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
             scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
         vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
       input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
         vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
             <<<grid, block, 0, stream>>>(
-                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
+                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                 input.data_ptr<scalar_t>(),
                 scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                 hidden_size);
@@ -2,6 +2,13 @@ from typing import Optional, Tuple, Union
 
 import torch
 
+from vllm.utils import is_hip
+
+# Using the default value (240.0) from pytorch will cause accuracy
+# issue on dynamic quantization models. Here use 224.0 for rocm.
+ROCM_FP8_MAX = 224.0
+FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+
 
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
                                 scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    assert quant_dtype in [torch.int8, FP8_DTYPE]
     if scale_ub is not None:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
 
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
-    qtype_max = as_float32_tensor(qtype_traits.max)
+    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
+    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)
 
@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
         iscales = as_float32_tensor(s_1 / scales)
         torch_out = as_float32_tensor(x) * iscales
         torch_out = torch_out.round()
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
     else:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
         min_scaling_factor = s_1 / (qtype_max * s_512)
         scales = scales.clamp(min=min_scaling_factor)
         torch_out = as_float32_tensor(x) / scales
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
 
     return torch_out, scales
 
@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
 def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    fp8_traits = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = as_float32_tensor(fp8_traits.max)
+    fp8_traits = torch.finfo(FP8_DTYPE)
+    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
+    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
+    fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)
 
     # For fp8, in order to match the cuda kernel output, we have to do exactly
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_scale = x_max / fp8_max
     ref_iscale = one / ref_scale
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
-        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
+        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
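
The reference helpers above now pick the fp8 dtype per platform via FP8_DTYPE. A brief usage sketch (illustrative only; it assumes a CUDA or ROCm device is available, since the helpers allocate on 'cuda'):

import torch
from tests.kernels.quant_utils import (FP8_DTYPE,
                                       ref_dynamic_per_tensor_fp8_quant,
                                       ref_dynamic_per_token_quant)

x = torch.randn(4, 16, dtype=torch.float16, device='cuda')

# Per-tensor: one scale for the whole tensor.
per_tensor_out, per_tensor_scale = ref_dynamic_per_tensor_fp8_quant(x)

# Per-token: one scale per row; quant_dtype must be FP8_DTYPE for fp8.
per_token_out, per_token_scales = ref_dynamic_per_token_quant(
    x, FP8_DTYPE, scale_ub=None)

assert per_tensor_out.dtype == FP8_DTYPE
assert per_token_out.dtype == FP8_DTYPE
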
@@ -2,7 +2,8 @@ import pytest
 import torch
 
 import vllm._custom_ops as ops
-from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
+from tests.kernels.quant_utils import (FP8_DTYPE,
+                                       ref_dynamic_per_tensor_fp8_quant,
                                        ref_dynamic_per_token_quant)
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -31,8 +32,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
 
     scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
             if scale_ub else None
-    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
-                                                      scale_ub)
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
     ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                                scale_ub=scale_ub,
                                                use_per_token_if_dynamic=True)
@@ -369,9 +369,12 @@ def scaled_fp8_quant(
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
     shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
+    out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
+        else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
     if scale is None:
         if use_per_token_if_dynamic:
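
With this change the Python entry point allocates its output in the platform's fp8 dtype. An illustrative call (assumes a vLLM build with the fp8 kernels and a visible GPU; the float32 scale dtype is an assumption based on the kernel signature above):

import torch
import vllm._custom_ops as ops
from vllm.utils import is_hip

x = torch.randn(8, 64, dtype=torch.float16, device='cuda')

# Dynamic per-tensor quantization: the kernel computes the scale itself.
out, scale = ops.scaled_fp8_quant(x)

expected_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
assert out.dtype == expected_dtype
assert scale.dtype == torch.float32
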
@@ -240,7 +240,7 @@ class ModelConfig:
 
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["gptq", "squeezellm"]
+        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
        optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
@@ -20,10 +20,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     create_per_tensor_scale_param, cutlass_fp8_supported,
-    per_tensor_dequantize, requantize_with_max_scale)
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+from vllm.utils import is_hip, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -120,6 +121,9 @@ class Fp8LinearMethod(LinearMethodBase):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
+        # Disable marlin for rocm
+        if is_hip():
+            self.use_marlin = False
 
     def create_weights(
         self,
@@ -168,6 +172,8 @@ class Fp8LinearMethod(LinearMethodBase):
             scale = create_per_tensor_scale_param(output_partition_sizes,
                                                   **extra_weight_attrs)
             layer.register_parameter("input_scale", scale)
+        else:
+            layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint not serialized fp8, quantize the weights.
@@ -202,9 +208,23 @@ class Fp8LinearMethod(LinearMethodBase):
             # requantize the logical shards as a single weight.
             else:
                 # Dequant -> Quant with max scale so we can run per tensor.
+                weight = layer.weight
+                weight_scale = layer.weight_scale
+
+                # If rocm, use float8_e4m3fnuz.
+                if is_hip():
+                    weight, weight_scale, input_scale = \
+                        normalize_e4m3fn_to_e4m3fnuz(
+                            weight=weight,
+                            weight_scale=weight_scale,
+                            input_scale=layer.input_scale)
+                    if input_scale is not None:
+                        layer.input_scale = Parameter(input_scale,
+                                                      requires_grad=False)
+
                 weight_scale, weight = requantize_with_max_scale(
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
+                    weight=weight,
+                    weight_scale=weight_scale,
                     logical_widths=layer.logical_widths,
                 )
 
@@ -214,8 +234,6 @@ class Fp8LinearMethod(LinearMethodBase):
         if self.quant_config.activation_scheme == "static":
             layer.input_scale = Parameter(layer.input_scale.max(),
                                           requires_grad=False)
-        else:
-            layer.input_scale = None
 
         if self.use_marlin:
             prepare_fp8_layer_for_marlin(layer)
@@ -346,10 +364,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If rocm, use float8_e4m3fnuz as dtype
+            fp8_dtype = torch.float8_e4m3fnuz \
+                if is_hip() else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(layer.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+                                          dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -393,6 +413,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_input_scale.max(), requires_grad=False)
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False)
+            # If rocm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
@@ -6,9 +6,19 @@ from torch.nn import Parameter
 from vllm import _custom_ops as ops
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
+from vllm.utils import is_hip
+
+# scaled_mm in pytorch on rocm has a bug that requires always
+# providing scaling factor for result. This value is created
+# as global value to avoid multiple tensor allocations, and
+# can be removed once pytorch fixes the bug.
+TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
 
 
 def cutlass_fp8_supported() -> bool:
+    # cutlass is not supported on Rocm
+    if is_hip():
+        return False
     capability = current_platform.get_device_capability()
     capability = capability[0] * 10 + capability[1]
 
@@ -147,13 +157,19 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output, _ = torch._scaled_mm(qinput,
-                                     weight,
-                                     out_dtype=input.dtype,
-                                     scale_a=x_scale,
-                                     scale_b=weight_scale,
-                                     bias=bias)
+        output = torch._scaled_mm(
+            qinput,
+            weight,
+            out_dtype=input.dtype,
+            scale_a=x_scale,
+            scale_b=weight_scale,
+            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
+            bias=bias)
+        # Since in torch 2.5, scaled_mm only returns single value
+        # This should be removed when vllm-nvidia also moves to 2.5
+        if is_hip():
+            return torch.narrow(output, 0, 0, input.shape[0])
+        return torch.narrow(output[0], 0, 0, input.shape[0])
 
     else:
         # Fallback for channelwise case, where we use unfused DQ
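
The output vs output[0] split above exists because the ROCm torch build in use already returns a single tensor from torch._scaled_mm (the torch 2.5 behavior), while the CUDA build still returns a tuple whose first element is the result. A version-agnostic unwrap is one way to express the same idea (illustrative sketch; maybe_unwrap_scaled_mm_output is not a vLLM function):

def maybe_unwrap_scaled_mm_output(result):
    # Older torch builds return a tuple from torch._scaled_mm; newer ones
    # (and the ROCm build targeted here) return just the output tensor.
    if isinstance(result, tuple):
        return result[0]
    return result
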
@@ -207,3 +223,27 @@ def apply_int8_linear(
                             scale_b=weight_scale,
                             out_dtype=input.dtype,
                             bias=bias)
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
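
As a sanity check on the new helper, dequantized values should be unchanged by the conversion: the same bit pattern is worth half as much in e4m3fnuz, and the doubled scale compensates. An illustrative sketch (assumes a PyTorch build where both fp8 dtypes are available; tensor shapes and the 0.05 scale are arbitrary):

import torch
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    normalize_e4m3fn_to_e4m3fnuz)

w_fn = (torch.randn(4, 8) * 10).to(torch.float8_e4m3fn)
scale = torch.tensor([0.05])

# Pass a clone so the original e4m3fn weight survives for comparison
# (the helper rewrites the weight's bit pattern in place via view()).
w_fnuz, new_scale, _ = normalize_e4m3fn_to_e4m3fnuz(w_fn.clone(), scale)

assert w_fnuz.dtype == torch.float8_e4m3fnuz
assert torch.equal(new_scale, scale * 2.0)
# Same dequantized values: fnuz bits are worth half as much, scale is doubled.
assert torch.allclose(w_fn.float() * scale, w_fnuz.float() * new_scale)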