[ Kernel ] FP8 Dynamic Per Token Quant - Add scale_ub (#6593)

Co-authored-by: Varun Sundar Rabindranth <varun@neuralmagic.com>
Varun Sundar Rabindranath 2024-07-19 21:15:26 -04:00 committed by GitHub
parent e81522e879
commit 2e26564259
6 changed files with 86 additions and 39 deletions

View File

@@ -134,9 +134,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
 void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scale);
 
-void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,
-                                        torch::Tensor const& input,
-                                        torch::Tensor& scale);
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
+    c10::optional<torch::Tensor> const& scale_ub);
 
 void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           int64_t block_size, torch::Tensor sorted_token_ids,

View File

@@ -23,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
 
 #define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
 
-template <typename scalar_t>
+template <bool is_scale_inverted>
 __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
-    const scalar_t val, const float inverted_scale) {
-  float x = static_cast<float>(val) * inverted_scale;
+    float const val, float const scale) {
+  float x = 0.0f;
+  if constexpr (is_scale_inverted) {
+    x = val * scale;
+  } else {
+    x = val / scale;
+  }
+
   float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
   return static_cast<c10::Float8_e4m3fn>(r);
 }
@@ -117,10 +123,10 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   return absmax_val;
 }
 
-template <typename scalar_t>
+template <typename scalar_t, bool is_scale_inverted>
 __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
                                           scalar_t const* __restrict__ input,
-                                          float const inverted_scale,
+                                          float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
   // Vectorized input/output to better utilize memory bandwidth.
@@ -135,16 +141,21 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;
 
-    out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
-    out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
-    out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
-    out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
+    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.x), scale);
+    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.y), scale);
+    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.z), scale);
+    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(in_vec.w), scale);
     vectorized_out[i] = out_vec;
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
   for (int i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
-    out[i] = scaled_fp8_conversion(input[i], inverted_scale);
+    out[i] = scaled_fp8_conversion<is_scale_inverted>(
+        static_cast<float>(input[i]), scale);
   }
 }
@@ -158,15 +169,17 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
   // Invert the scale so that we can use multiplications to avoid expensive
   // division.
   const float inverted_scale = 1.0f / (*scale);
-  scaled_fp8_conversion_vec(out, input, inverted_scale, num_elems, tid,
-                            blockDim.x * gridDim.x);
+  scaled_fp8_conversion_vec<scalar_t, true>(
+      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }
 
 template <typename scalar_t>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
-    scalar_t const* __restrict__ input, const int hidden_size) {
+    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
+    const int hidden_size) {
+  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
@@ -188,20 +201,27 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   }
 
   float const block_absmax_val_maybe = blockReduceMax(absmax_val);
-  __shared__ float block_absmax_val;
+  __shared__ float token_scale;
   if (tid == 0) {
-    block_absmax_val = block_absmax_val_maybe;
-    scale[token_idx] = block_absmax_val / FP8_E4M3_MAX;
+    if (scale_ub) {
+      token_scale = min(block_absmax_val_maybe, *scale_ub);
+    } else {
+      token_scale = block_absmax_val_maybe;
+    }
+    // token scale computation
+    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    scale[token_idx] = token_scale;
   }
   __syncthreads();
 
-  float const inverted_scale = FP8_E4M3_MAX / block_absmax_val;
+  // Note that we don't use inverted scales so we can match FBGemm impl.
   if (can_vectorize) {
-    scaled_fp8_conversion_vec(token_output, token_input, inverted_scale,
-                              hidden_size, tid, blockDim.x);
+    scaled_fp8_conversion_vec<scalar_t, false>(
+        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion(token_input[i], inverted_scale);
+      token_output[i] = scaled_fp8_conversion<false>(
+          static_cast<float>(token_input[i]), token_scale);
     }
   }
 }
@@ -246,9 +266,10 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
       });
 }
 
-void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
-                                        torch::Tensor const& input,  // [..., d]
-                                        torch::Tensor& scales) {
+void dynamic_per_token_scaled_fp8_quant(
+    torch::Tensor& out,          // [..., d]
+    torch::Tensor const& input,  // [..., d]
+    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
@@ -264,6 +285,8 @@ void dynamic_per_token_scaled_fp8_quant(torch::Tensor& out, // [..., d]
         vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
             <<<grid, block, 0, stream>>>(
                 out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
-                input.data_ptr<scalar_t>(), hidden_size);
+                input.data_ptr<scalar_t>(),
+                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
+                hidden_size);
       });
 }
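
For reference, the per-token scale logic above reduces to: take each row's absmax, optionally cap it at *scale_ub, divide by the FP8 E4M3 maximum (448), floor the result at 1 / (448 * 512), and divide the inputs by that scale (no inverted scale, to match the FBGemm implementation). Below is a minimal PyTorch sketch of the same computation, purely illustrative; the function and variable names are not part of the diff.

    from typing import Optional, Tuple
    import torch

    FP8_MAX = float(torch.finfo(torch.float8_e4m3fn).max)  # 448.0 for e4m3fn
    MIN_SCALING_FACTOR = 1.0 / (FP8_MAX * 512.0)            # same floor as the kernel

    def per_token_fp8_quant_sketch(
            x: torch.Tensor,
            scale_ub: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Row-wise absmax, optionally capped by the upper bound.
        absmax = x.abs().max(dim=-1).values.float()
        if scale_ub is not None:
            absmax = absmax.clamp(max=scale_ub)
        # One scale per token, floored so all-zero rows don't divide by zero.
        scales = (absmax / FP8_MAX).clamp(min=MIN_SCALING_FACTOR)[:, None]
        # Divide (not multiply by an inverse) and saturate to the FP8 range.
        out = (x.float() / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
        return out, scales

This mirrors the fp8 branch of ref_dynamic_per_token_quant in the test utilities further below.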

View File

@@ -188,7 +188,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
   ops.def(
       "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! "
-      "scale) -> "
+      "scale, Tensor? scale_ub) -> "
       "()");
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
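
In the schema above, `Tensor? scale_ub` makes the new argument optional at the Python boundary: passing None arrives as an empty optional in the C++ binding. A hedged sketch of invoking the registered op directly, assuming the compiled extension that registers the _C ops has already been loaded (for example by importing vLLM); shapes and values are illustrative:

    import torch
    import vllm  # assumption: importing vLLM loads the compiled _C extension

    x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = torch.empty(x.shape[0], 1, dtype=torch.float32, device="cuda")
    scale_ub = torch.tensor(0.5, dtype=torch.float32, device="cuda")

    # With an upper bound on the per-token scale...
    torch.ops._C.dynamic_per_token_scaled_fp8_quant(out, x, scales, scale_ub)
    # ...and without one (Tensor? also accepts None).
    torch.ops._C.dynamic_per_token_scaled_fp8_quant(out, x, scales, None)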

View File

@@ -1,4 +1,4 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
@@ -7,13 +7,19 @@ def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
     return torch.as_tensor(x, dtype=torch.float32, device='cuda')
 
 def ref_dynamic_per_token_quant(x: torch.tensor,
-                                quant_dtype: torch.dtype) \
+                                quant_dtype: torch.dtype,
+                                scale_ub: Optional[torch.tensor] = None) \
         -> Tuple[torch.tensor, torch.tensor]:
 
     assert quant_dtype in [torch.int8, torch.float8_e4m3fn]
+    if scale_ub is not None:
+        assert quant_dtype == torch.float8_e4m3fn
 
     qtype_traits = torch.iinfo(quant_dtype) if quant_dtype == torch.int8 \
             else torch.finfo(quant_dtype)
     qtype_max = as_float32_tensor(qtype_traits.max)
+    s_1 = as_float32_tensor(1.0)
+    s_512 = as_float32_tensor(512.0)
 
     # For fp8, in order to match the cuda kernel output, we have to do exactly
     # the same operations as in the corresponding fp8 kernel to prevent
@@ -22,14 +28,24 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
     # Compute scales
     x_token_max, _ = x.abs().max(dim=-1)
     x_token_max = as_float32_tensor(x_token_max)
+    if scale_ub is not None:
+        x_token_max = x_token_max.clamp(max=scale_ub)
     scales = (x_token_max / qtype_max)[:, None]
 
     # Quant
-    iscales = (qtype_max / x_token_max)[:, None]
-    torch_out = as_float32_tensor(x) * iscales
-    torch_out = torch_out.round() if quant_dtype == torch.int8 else torch_out
-    torch_out = torch_out.clamp(qtype_traits.min,
-                                qtype_traits.max).to(quant_dtype)
+    if quant_dtype == torch.int8:
+        iscales = as_float32_tensor(s_1 / scales)
+        torch_out = as_float32_tensor(x) * iscales
+        torch_out = torch_out.round()
+        torch_out = torch_out.clamp(qtype_traits.min,
+                                    qtype_traits.max).to(quant_dtype)
+    else:
+        assert quant_dtype == torch.float8_e4m3fn
+        min_scaling_factor = s_1 / (qtype_max * s_512)
+        scales = scales.clamp(min=min_scaling_factor)
+        torch_out = as_float32_tensor(x) / scales
+        torch_out = torch_out.clamp(qtype_traits.min,
+                                    qtype_traits.max).to(quant_dtype)
 
     return torch_out, scales

View File

@@ -10,24 +10,31 @@ HIDDEN_SIZES = [1, 2, 3, 4, 16, 67, 768, 2048, 5120, 5137, 8192,
               8193]  # Arbitrary values for testing
 HIDDEN_SIZES += list(range(1024, 1033))  # vectorized conversion edge cases
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
+SCALE_UBS = [True, False]
 SEEDS = [0]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("scale_ub", SCALE_UBS)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
-                                     dtype: torch.dtype, seed: int) -> None:
+                                     dtype: torch.dtype, scale_ub: bool,
+                                     seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
 
     x = torch.rand(num_tokens, hidden_size, dtype=dtype,
                    device="cuda") + 1e-6  # avoid nans
 
-    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn)
+    scale_ub = torch.mean(x).to(dtype=torch.float32, device='cuda') \
+        if scale_ub else None
+    ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.float8_e4m3fn,
+                                                      scale_ub)
     ops_out, ops_scales = ops.scaled_fp8_quant(x,
+                                               scale_ub=scale_ub,
                                                use_per_token_if_dynamic=True)
 
     assert torch.allclose(ref_scales, ops_scales)

View File

@@ -300,6 +300,7 @@ def scaled_fp8_quant(
     input: torch.Tensor,
     scale: Optional[torch.Tensor] = None,
     batch_dim_padding: Optional[int] = None,
+    scale_ub: Optional[torch.Tensor] = None,
     use_per_token_if_dynamic: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -336,7 +337,7 @@ def scaled_fp8_quant(
                                 device=input.device,
                                 dtype=torch.float32)
             torch.ops._C.dynamic_per_token_scaled_fp8_quant(
-                output, input, scale)
+                output, input, scale, scale_ub)
         else:
             scale = torch.zeros(1, device=input.device, dtype=torch.float32)
             torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
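
End to end, the new keyword is threaded through ops.scaled_fp8_quant, which is exactly the path the updated test exercises. A minimal usage sketch under that assumption (the bound value of 1.0 is arbitrary):

    import torch
    from vllm import _custom_ops as ops

    x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    # Cap each token's scale so a few outliers can't blow up the dynamic range.
    scale_ub = torch.tensor(1.0, dtype=torch.float32, device="cuda")

    out, scales = ops.scaled_fp8_quant(x,
                                       scale_ub=scale_ub,
                                       use_per_token_if_dynamic=True)
    # out: float8_e4m3fn tensor of shape (8, 4096)
    # scales: one float32 scale per token, min(row absmax, scale_ub) / 448,
    #         floored at 1 / (448 * 512)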