[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (#7210)
parent ec724a725e
commit e837b624f2
@@ -9,6 +9,18 @@
 
 #include "../../reduction_utils.cuh"
 
+#ifndef USE_ROCM
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
+    std::numeric_limits<FP8_TYPE>::max();
+#else
+  #include "amd/hip_float8.h"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+// Using the default max value from pytorch (240.0) will cause accuracy
+// issue when running dynamic quantization. Here use 224.0f for rocm.
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
 namespace vllm {
 
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
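The block above selects the fp8 storage type and its clamping bound per platform: OCP e4m3fn on CUDA, the fnuz variant on ROCm, with the maximum lowered from 240.0 to 224.0. A minimal PyTorch sketch (illustrative, not part of the commit) of the ranges involved:

    import torch

    # Representable maxima of the two fp8 formats, per torch.finfo.
    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0 -> CUDA path
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0 -> ROCm path

    # The kernel clamps to 224.0 on ROCm instead of the full 240.0 because,
    # per the comment in the hunk, 240.0 hurts dynamic-quantization accuracy.
    ROCM_FP8_MAX = 224.0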
@@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
-
 template <bool is_scale_inverted>
-__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    float const val, float const scale) {
+__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+                                                          float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
     x = val * scale;
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
   }
 
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+#ifndef USE_ROCM
   return static_cast<c10::Float8_e4m3fn>(r);
+#else
+  // Use hardware cvt instruction for fp8 on rocm
+  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
+                              c10::Float8_e4m3fnuz::from_bits());
+#endif
 }
 
 // Compute the absolute maximum m of the input tensor and store
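scaled_fp8_conversion now returns the platform fp8 type, clamps to the platform FP8_E4M3_MAX, and on ROCm converts through hip_fp8. A host-side Python sketch of the same math (hypothetical helper; torch's cast stands in for the hardware convert):

    import torch

    FP8_E4M3_MAX = 224.0  # ROCm bound from the hunk above; 448.0 on the CUDA path

    def scaled_fp8_conversion_ref(val: torch.Tensor, scale: torch.Tensor,
                                  is_scale_inverted: bool = False) -> torch.Tensor:
        # Scale (or multiply by an already-inverted scale), clamp, then cast to fp8.
        x = val * scale if is_scale_inverted else val / scale
        r = x.clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
        return r.to(torch.float8_e4m3fnuz)  # stand-in for the hip_fp8 hardware cvt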
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale,
-                   cache[0] / std::numeric_limits<c10::Float8_e4m3fn>::max());
+    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
   }
 }
 
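segmented_max_reduction now derives the dynamic per-tensor scale from the platform FP8_E4M3_MAX. The quantity it accumulates is the global absolute maximum divided by that bound; an equivalent host-side one-liner (hypothetical helper, for reference only):

    import torch

    def dynamic_per_tensor_scale_ref(x: torch.Tensor,
                                     fp8_max: float = 224.0) -> torch.Tensor:
        # Same value the kernel produces via blockwise reduction + atomicMaxFloat.
        return x.abs().max().to(torch.float32) / fp8_max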
@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
 };
 
 typedef struct __align__(4) {
-  c10::Float8_e4m3fn x;
-  c10::Float8_e4m3fn y;
-  c10::Float8_e4m3fn z;
-  c10::Float8_e4m3fn w;
+  FP8_TYPE x;
+  FP8_TYPE y;
+  FP8_TYPE z;
+  FP8_TYPE w;
 }
 float8x4_t;
 
@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
 }
 
 template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
+__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
 }
 
 template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
+__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
 
 template <typename scalar_t>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
+    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
   float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   int const token_idx = blockIdx.x;
 
   scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
-  c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
+  FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];
 
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<c10::Float8_e4m3fn>(), input.data_ptr<scalar_t>(),
+            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
             scale.data_ptr<float>(), num_elems);
       });
 }
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
       input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
         vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
             <<<grid, block, 0, stream>>>(
-                out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
+                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                 input.data_ptr<scalar_t>(),
                 scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                 hidden_size);
@@ -2,6 +2,13 @@ from typing import Optional, Tuple, Union
 
 import torch
 
+from vllm.utils import is_hip
+
+# Using the default value (240.0) from pytorch will cause accuracy
+# issue on dynamic quantization models. Here use 224.0 for rocm.
+ROCM_FP8_MAX = 224.0
+FP8_DTYPE = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+
 
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
                                 scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    assert quant_dtype in [torch.int8, FP8_DTYPE]
     if scale_ub is not None:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
 
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
-    qtype_max = as_float32_tensor(qtype_traits.max)
+    qtype_traits_max = ROCM_FP8_MAX if is_hip() else qtype_traits.max
+    qtype_traits_min = -ROCM_FP8_MAX if is_hip() else qtype_traits.min
+    qtype_max = as_float32_tensor(qtype_traits_max)
     s_1 = as_float32_tensor(1.0)
     s_512 = as_float32_tensor(512.0)
 
@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
         iscales = as_float32_tensor(s_1 / scales)
         torch_out = as_float32_tensor(x) * iscales
         torch_out = torch_out.round()
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
     else:
-        assert quant_dtype == torch.float8_e4m3fn
+        assert quant_dtype == FP8_DTYPE
         min_scaling_factor = s_1 / (qtype_max * s_512)
         scales = scales.clamp(min=min_scaling_factor)
         torch_out = as_float32_tensor(x) / scales
-        torch_out = torch_out.clamp(qtype_traits.min,
-                                    qtype_traits.max).to(quant_dtype)
+        torch_out = torch_out.clamp(qtype_traits_min,
+                                    qtype_traits_max).to(quant_dtype)
 
     return torch_out, scales
 
@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
 def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
         -> Tuple[torch.tensor, torch.tensor]:
 
-    fp8_traits = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = as_float32_tensor(fp8_traits.max)
+    fp8_traits = torch.finfo(FP8_DTYPE)
+    fp8_traits_max = ROCM_FP8_MAX if is_hip() else fp8_traits.max
+    fp8_traits_min = -ROCM_FP8_MAX if is_hip() else fp8_traits.min
+    fp8_max = as_float32_tensor(fp8_traits_max)
     one = as_float32_tensor(1.0)
 
     # For fp8, in order to match the cuda kernel output, we have to do exactly
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_scale = x_max / fp8_max
     ref_iscale = one / ref_scale
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
-        fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
+        fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
     return ref_out, ref_scale.view((1, ))
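With ROCM_FP8_MAX and FP8_DTYPE in place, the reference helpers quantize to the platform's fp8 dtype. A usage sketch (illustrative shapes; FP8_DTYPE and the helper come from tests/kernels/quant_utils.py as updated above):

    import torch
    from tests.kernels.quant_utils import FP8_DTYPE, ref_dynamic_per_token_quant

    x = torch.randn(4, 16, dtype=torch.float16, device='cuda')
    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE)
    print(ref_out.dtype)     # torch.float8_e4m3fnuz on ROCm, torch.float8_e4m3fn elsewhere
    print(ref_scales.dtype)  # float32 per-token scales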
@@ -2,7 +2,8 @@ import pytest
 import torch
 
 import vllm._custom_ops as ops
-from tests.kernels.quant_utils import (ref_dynamic_per_tensor_fp8_quant,
+from tests.kernels.quant_utils import (FP8_DTYPE,
+                                       ref_dynamic_per_tensor_fp8_quant,
                                        ref_dynamic_per_token_quant)
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
@@ -31,8 +32,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
 
     scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
         if scale_ub else None
-    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
-                                                      scale_ub)
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, FP8_DTYPE, scale_ub)
    ops_out, ops_scales = ops.scaled_fp8_quant(x,
                                               scale_ub=scale_ub,
                                               use_per_token_if_dynamic=True)
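The test then compares the kernel output against the reference. A sketch of the usual comparison pattern (assumed, not copied from the test): upcast the fp8 tensors before comparing, since most elementwise ops are not implemented for fp8 dtypes:

    import torch

    assert torch.allclose(ref_scales, ops_scales)
    assert torch.allclose(ref_out.to(torch.float32), ops_out.to(torch.float32))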
@@ -369,9 +369,12 @@ def scaled_fp8_quant(
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
     shape: Union[Tuple[int, int], torch.Size] = input.shape
+    # For rocm, the output fp8 dtype is torch.float8_e4m3fnuz
+    out_dtype: torch.dtype = torch.float8_e4m3fnuz if vllm.utils.is_hip() \
+        else torch.float8_e4m3fn
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
-    output = torch.empty(shape, device=input.device, dtype=torch.float8_e4m3fn)
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
 
     if scale is None:
         if use_per_token_if_dynamic:
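scaled_fp8_quant now allocates its output in the platform fp8 dtype. A minimal usage sketch (illustrative sizes; the dynamic path computes the scale internally):

    import torch
    import vllm._custom_ops as ops

    x = torch.randn(8, 64, dtype=torch.float16, device='cuda')
    out, scale = ops.scaled_fp8_quant(x)  # dynamic per-tensor quantization
    # out.dtype is torch.float8_e4m3fnuz on ROCm and torch.float8_e4m3fn otherwise.
    print(out.dtype, scale.item())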
@@ -240,7 +240,7 @@ class ModelConfig:
 
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
-        rocm_supported_quantization = ["gptq", "squeezellm"]
+        rocm_supported_quantization = ["gptq", "squeezellm", "fp8"]
         optimized_quantization_methods = [
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
@@ -20,10 +20,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
     create_per_tensor_scale_param, cutlass_fp8_supported,
-    per_tensor_dequantize, requantize_with_max_scale)
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
-from vllm.utils import print_warning_once
+from vllm.utils import is_hip, print_warning_once
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
@@ -120,6 +121,9 @@ class Fp8LinearMethod(LinearMethodBase):
         capability = current_platform.get_device_capability()
         capability = capability[0] * 10 + capability[1]
         self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
+        # Disable marlin for rocm
+        if is_hip():
+            self.use_marlin = False
 
     def create_weights(
         self,
@@ -168,6 +172,8 @@ class Fp8LinearMethod(LinearMethodBase):
             scale = create_per_tensor_scale_param(output_partition_sizes,
                                                   **extra_weight_attrs)
             layer.register_parameter("input_scale", scale)
+        else:
+            layer.register_parameter("input_scale", None)
 
     def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint not serialized fp8, quantize the weights.
@@ -202,9 +208,23 @@ class Fp8LinearMethod(LinearMethodBase):
         # requantize the logical shards as a single weight.
         else:
             # Dequant -> Quant with max scale so we can run per tensor.
+            weight = layer.weight
+            weight_scale = layer.weight_scale
+
+            # If rocm, use float8_e4m3fnuz.
+            if is_hip():
+                weight, weight_scale, input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        weight=weight,
+                        weight_scale=weight_scale,
+                        input_scale=layer.input_scale)
+                if input_scale is not None:
+                    layer.input_scale = Parameter(input_scale,
+                                                  requires_grad=False)
+
             weight_scale, weight = requantize_with_max_scale(
-                weight=layer.weight,
-                weight_scale=layer.weight_scale,
+                weight=weight,
+                weight_scale=weight_scale,
                 logical_widths=layer.logical_widths,
             )
 
@@ -214,8 +234,6 @@ class Fp8LinearMethod(LinearMethodBase):
             if self.quant_config.activation_scheme == "static":
                 layer.input_scale = Parameter(layer.input_scale.max(),
                                               requires_grad=False)
-            else:
-                layer.input_scale = None
 
             if self.use_marlin:
                 prepare_fp8_layer_for_marlin(layer)
@@ -346,10 +364,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
+            # If rocm, use float8_e4m3fnuz as dtype
+            fp8_dtype = torch.float8_e4m3fnuz \
+                if is_hip() else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(layer.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+                                          dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
 
             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -393,6 +413,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     layer.w13_input_scale.max(), requires_grad=False)
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False)
+            # If rocm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_weight_scale, w13_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w13_weight, layer.w13_weight_scale,
+                        layer.w13_input_scale)
+                w2_weight, w2_weight_scale, w2_input_scale = \
+                    normalize_e4m3fn_to_e4m3fnuz(
+                        layer.w2_weight, layer.w2_weight_scale,
+                        layer.w2_input_scale)
+                # Reset the parameter
+                layer.w13_weight = torch.nn.Parameter(w13_weight,
+                                                      requires_grad=False)
+                layer.w13_weight_scale = torch.nn.Parameter(
+                    w13_weight_scale, requires_grad=False)
+                if w13_input_scale is not None:
+                    layer.w13_input_scale = torch.nn.Parameter(
+                        w13_input_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight,
+                                                     requires_grad=False)
+                layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
+                                                           requires_grad=False)
+                if w2_input_scale is not None:
+                    layer.w2_input_scale = torch.nn.Parameter(
+                        w2_input_scale, requires_grad=False)
 
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils import is_hip
|
||||||
|
|
||||||
|
# scaled_mm in pytorch on rocm has a bug that requires always
|
||||||
|
# providing scaling factor for result. This value is created
|
||||||
|
# as global value to avoid multiple tensor allocations, and
|
||||||
|
# can be removed once pytorch fixes the bug.
|
||||||
|
TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
|
||||||
|
|
||||||
|
|
||||||
def cutlass_fp8_supported() -> bool:
|
def cutlass_fp8_supported() -> bool:
|
||||||
|
# cutlass is not supported on Rocm
|
||||||
|
if is_hip():
|
||||||
|
return False
|
||||||
capability = current_platform.get_device_capability()
|
capability = current_platform.get_device_capability()
|
||||||
capability = capability[0] * 10 + capability[1]
|
capability = capability[0] * 10 + capability[1]
|
||||||
|
|
||||||
@@ -147,13 +157,19 @@ def apply_fp8_linear(
 
     if per_tensor_weights and per_tensor_activations:
         # Fused GEMM_DQ
-        output, _ = torch._scaled_mm(qinput,
+        output = torch._scaled_mm(
+            qinput,
             weight,
             out_dtype=input.dtype,
             scale_a=x_scale,
             scale_b=weight_scale,
+            scale_result=TORCH_SCALED_MM_SCALE_RESULT,
             bias=bias)
+        # Since in torch 2.5, scaled_mm only returns single value
+        # This should be removed when vllm-nvidia also moves to 2.5
+        if is_hip():
             return torch.narrow(output, 0, 0, input.shape[0])
+        return torch.narrow(output[0], 0, 0, input.shape[0])
 
     else:
         # Fallback for channelwise case, where we use unfused DQ
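The fused per-tensor path now passes scale_result on ROCm and handles the differing return type of torch._scaled_mm (the ROCm build used here returns a tensor, while the CUDA build still returns a tuple). A condensed sketch of that control flow with hypothetical inputs:

    import torch

    def fp8_gemm_sketch(qinput, weight, x_scale, weight_scale, out_dtype,
                        rocm: bool):
        # scale_result works around the rocm scaled_mm issue noted above.
        scale_result = torch.ones(1, device=qinput.device) if rocm else None
        output = torch._scaled_mm(qinput,
                                  weight,
                                  out_dtype=out_dtype,
                                  scale_a=x_scale,
                                  scale_b=weight_scale,
                                  scale_result=scale_result,
                                  bias=None)
        # ROCm build: tensor; CUDA build: (tensor, amax) tuple.
        return output if rocm else output[0]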
@@ -207,3 +223,27 @@ def apply_int8_linear(
                               scale_b=weight_scale,
                               out_dtype=input.dtype,
                               bias=bias)
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+        weight: torch.Tensor,
+        weight_scale: torch.Tensor,
+        input_scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
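normalize_e4m3fn_to_e4m3fnuz reinterprets the e4m3fn bit patterns as e4m3fnuz and doubles the scales. A small illustrative check of why the doubling preserves dequantized values (the same bits decode to half the value in e4m3fnuz because its exponent bias is 8 rather than 7):

    import torch

    v_fn = torch.tensor([2.0]).to(torch.float8_e4m3fn)
    v_fnuz = v_fn.view(torch.int8).view(torch.float8_e4m3fnuz)
    print(v_fn.to(torch.float32), v_fnuz.to(torch.float32))  # 2.0 vs 1.0

    scale = torch.tensor([0.5])
    dequant_before = v_fn.to(torch.float32) * scale
    dequant_after = v_fnuz.to(torch.float32) * (scale * 2.0)
    assert torch.equal(dequant_before, dequant_after)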