[ Kernel ] FP8 Dynamic Per Token Quant - Add scale_ub (#6593)
Co-authored-by: Varun Sundar Rabindranth <varun@neuralmagic.com>
This commit is contained in:
parent
e81522e879
commit
2e26564259
@ -134,9 +134,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
|
||||
torch::Tensor const& input,
|
||||
torch::Tensor& scale);
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
||||
c10::optional<torch::Tensor> const& scale_ub);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
|
@ -23,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
|
||||
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
|
||||
|
||||
template <typename scalar_t>
|
||||
template <bool is_scale_inverted>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
||||
const scalar_t val, const float inverted_scale) {
|
||||
float x = static_cast<float>(val) * inverted_scale;
|
||||
float const val, float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
x = val * scale;
|
||||
} else {
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
}
|
||||
@ -117,10 +123,10 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
return absmax_val;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, bool is_scale_inverted>
|
||||
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
||||
scalar_t const* __restrict__ input,
|
||||
float const inverted_scale,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
@ -135,16 +141,21 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
|
||||
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
|
||||
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
|
||||
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
|
||||
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.x), scale);
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.y), scale);
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.z), scale);
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.w), scale);
|
||||
vectorized_out[i] = out_vec;
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(input[i]), scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -158,15 +169,17 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
// Invert the scale so that we can use multiplications to avoid expensive
|
||||
// division.
|
||||
const float inverted_scale = 1.0f / (*scale);
|
||||
|
||||
scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
|
||||
blockDim.x * gridDim.x);
|
||||
scaled_fp8_conversion_vec<scalar_t, true>(
|
||||
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, const int hidden_size) {
|
||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||
const int hidden_size) {
|
||||
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
|
||||
@ -188,20 +201,27 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
}
|
||||
|
||||
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
|
||||
__shared__ float block_absmax_val;
|
||||
__shared__ float token_scale;
|
||||
if (tid == 0) {
|
||||
block_absmax_val = block_absmax_val_maybe;
|
||||
scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
|
||||
if (scale_ub) {
|
||||
token_scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
token_scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
|
||||
// Note that we don't use inverted scales so we can match FBGemm impl.
|
||||
if (can_vectorize) {
|
||||
scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
|
||||
hidden_size, tid, blockDim.x);
|
||||
scaled_fp8_conversion_vec<scalar_t, false>(
|
||||
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
|
||||
token_output[i] = scaled_fp8_conversion<false>(
|
||||
static_cast<float>(token_input[i]), token_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -246,9 +266,10 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales) {
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
||||
@ -264,6 +285,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(), hidden_size);
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
}
|
||||
|
@ -188,7 +188,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
|
||||
"scale) -> "
|
||||
"scale, Tensor? scale_ub) -> "
|
||||
"()");
|
||||
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||
&dynamic_per_token_scaled_fp8_quant);
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -7,13 +7,19 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
|
||||
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
|
||||
|
||||
def ref_dynamic_per_token_quant(x: torch.tensor,
|
||||
quant_dtype: torch.dtype) \
|
||||
quant_dtype: torch.dtype,
|
||||
scale_ub: Optional[torch.tensor] = None) \
|
||||
-> Tuple[torch.tensor, torch.tensor]:
|
||||
|
||||
assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
|
||||
if scale_ub is not None:
|
||||
assert quant_dtype == torch.float8_e4m3fn
|
||||
|
||||
qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
|
||||
else torch.finfo(quant_dtype)
|
||||
qtype_max = as_float32_tensor(qtype_traits.max)
|
||||
s_1 = as_float32_tensor(1.0)
|
||||
s_512 = as_float32_tensor(512.0)
|
||||
|
||||
# For fp8, in order to match the cuda kernel output, we have to do exactly
|
||||
# the same operations as in the corresponding fp8 kernel to prevent
|
||||
@ -22,14 +28,24 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
|
||||
# Compute scales
|
||||
x_token_max, _ = x.abs().max(dim=-1)
|
||||
x_token_max = as_float32_tensor(x_token_max)
|
||||
if scale_ub is not None:
|
||||
x_token_max = x_token_max.clamp(max=scale_ub)
|
||||
scales = (x_token_max / qtype_max)[:, None]
|
||||
|
||||
# Quant
|
||||
iscales = (qtype_max / x_token_max)[:, None]
|
||||
torch_out = as_float32_tensor(x) * iscales
|
||||
torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
|
||||
torch_out = torch_out.clamp(qtype_traits.min,
|
||||
qtype_traits.max).to(quant_dtype)
|
||||
if quant_dtype == torch.int8:
|
||||
iscales = as_float32_tensor(s_1 / scales)
|
||||
torch_out = as_float32_tensor(x) * iscales
|
||||
torch_out = torch_out.round()
|
||||
torch_out = torch_out.clamp(qtype_traits.min,
|
||||
qtype_traits.max).to(quant_dtype)
|
||||
else:
|
||||
assert quant_dtype == torch.float8_e4m3fn
|
||||
min_scaling_factor = s_1 / (qtype_max * s_512)
|
||||
scales = scales.clamp(min=min_scaling_factor)
|
||||
torch_out = as_float32_tensor(x) / scales
|
||||
torch_out = torch_out.clamp(qtype_traits.min,
|
||||
qtype_traits.max).to(quant_dtype)
|
||||
|
||||
return torch_out, scales
|
||||
|
||||
|
@ -10,24 +10,31 @@ HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
|
||||
8193] # Arbitrary values for testing
|
||||
HIDDEN_SIZES += list(range(1024, 1033)) # vectorized conversion edge cases
|
||||
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
|
||||
SCALE_UBS = [True, False]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@torch.inference_mode()
|
||||
def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
|
||||
dtype: torch.dtype, seed: int) -> None:
|
||||
dtype: torch.dtype, scale_ub: bool,
|
||||
seed: int) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
x = torch.rand(num_tokens, hidden_size, dtype=dtype,
|
||||
device="cuda") + 1e-6 # avoid nans
|
||||
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
|
||||
scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
|
||||
if scale_ub else None
|
||||
ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
|
||||
scale_ub)
|
||||
ops_out, ops_scales = ops.scaled_fp8_quant(x,
|
||||
scale_ub=scale_ub,
|
||||
use_per_token_if_dynamic=True)
|
||||
|
||||
assert torch.allclose(ref_scales, ops_scales)
|
||||
|
@ -300,6 +300,7 @@ def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
batch_dim_padding: Optional[int] = None,
|
||||
scale_ub: Optional[torch.Tensor] = None,
|
||||
use_per_token_if_dynamic: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@ -336,7 +337,7 @@ def scaled_fp8_quant(
|
||||
device=input.device,
|
||||
dtype=torch.float32)
|
||||
torch.ops._C.dynamic_per_token_scaled_fp8_quant(
|
||||
output, input, scale)
|
||||
output, input, scale, scale_ub)
|
||||
else:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
|
Loading…
x
Reference in New Issue
Block a user