[Kernel] AQ AZP 3/4: Asymmetric quantization kernels (#7270)
parent 781e3b9a42
commit 5d73ae49d6
@@ -257,11 +257,13 @@ void int8_scaled_mm(torch::Tensor& c,  // [M, OC], row-major
 // static-per-tensor quantization.
 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               const torch::Tensor& input,  // [..., hidden_size]
-                              const torch::Tensor& scale) {
+                              const torch::Tensor& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(static_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

   const int hidden_size = input.size(-1);
   const int num_tokens = input.numel() / hidden_size;
@@ -277,11 +279,12 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     const torch::Tensor& input,  // [..., hidden_size]
-    torch::Tensor& scale         // [..., 1]
-) {
+    torch::Tensor& scale,  // [..., 1]
+    c10::optional<torch::Tensor> const& azp) {
   CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -94,13 +94,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 #ifdef __AVX512F__
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
-      "()");
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
+
   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
-      "()");
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
            &dynamic_scaled_int8_quant);
   // W8A8 GEMM, supporting symmetric per-tensor or per-row/column
@@ -184,10 +184,12 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
 #endif

 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
-                              torch::Tensor const& scale);
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp);

 void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
-                               torch::Tensor& scales);
+                               torch::Tensor& scales,
+                               c10::optional<torch::Tensor> const& azp);

 torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
                         torch::Tensor b_gptq_qzeros,
@@ -14,12 +14,17 @@

 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
-  static const float i8_min =
+  static constexpr auto i8_min =
       static_cast<float>(std::numeric_limits<int8_t>::min());
-  static const float i8_max =
+  static constexpr auto i8_max =
       static_cast<float>(std::numeric_limits<int8_t>::max());
-  // round
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
   float dst = std::nearbyint(x);
+
   // saturate
   dst = std::clamp(dst, i8_min, i8_max);
   return static_cast<int8_t>(dst);
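Note on rounding: both the CUDA path (cvt.rni) and the ROCm path (std::nearbyint under FE_TONEAREST) round half-way cases to the nearest even integer, which is also what torch.round does on the reference side. A minimal Python sketch, illustrative only and not part of the diff:

import torch

# Ties go to the even neighbor, matching cvt.rni / std::nearbyint.
x = torch.tensor([0.5, 1.5, 2.5, -0.5, -1.5])
print(torch.round(x))  # tensor([0., 2., 2., -0., -2.])

This shared tie-breaking behavior is one reason the tests below can compare the kernels against torch.round-based references with only a small absolute tolerance.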
@@ -31,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) {
 #endif
 }

+static inline __device__ int32_t float_to_int32_rn(float x) {
+#ifdef USE_ROCM
+  // int32_max is not exactly representable as float.
+  // Therefore, we need to be careful and manually return int32_max on overflow.
+  // For symmetry, we also do the same for int32_min, even though it is exactly
+  // representable as float and the conversion should be exact.
+  static constexpr auto i32_min = std::numeric_limits<int32_t>::min();
+  static constexpr auto i32_min_f = static_cast<float>(i32_min);
+  static constexpr auto i32_max = std::numeric_limits<int32_t>::max();
+  static constexpr auto i32_max_f = static_cast<float>(i32_max);
+
+  // To match the rounding mode of CUDA, we use nearbyint.
+  // It uses the current rounding mode, which is always FE_TONEAREST on HIP.
+  // If that changes in the future, we may need to set the rounding mode
+  // explicitly, either at runtime or compile time.
+  float dst = std::nearbyint(x);
+
+  // saturate on the higher end.
+  if (dst >= i32_max_f) {
+    return i32_max;
+  }
+  // saturate on the lower end.
+  if (dst <= i32_min_f) {
+    return i32_min;
+  }
+
+  return static_cast<int32_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x));
+  return reinterpret_cast<const int32_t&>(dst);
+#endif
+}
+
+static inline __device__ int8_t int32_to_int8(int32_t x) {
+#ifdef USE_ROCM
+  static constexpr auto i8_min =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::min());
+  static constexpr auto i8_max =
+      static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+
+  // saturate
+  int32_t dst = std::clamp(x, i8_min, i8_max);
+  return static_cast<int8_t>(dst);
+#else
+  // CUDA path
+  uint32_t dst;
+  asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x));
+  return reinterpret_cast<const int8_t&>(dst);
+#endif
+}
+
 namespace vllm {

 template <typename scalar_t, typename scale_type>
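The comment in float_to_int32_rn is worth unpacking: int32_max (2^31 - 1) has no exact float32 representation, so converting it to float rounds up to 2^31, and a plain cast back after rounding could overflow. A small numpy sketch, illustrative only, of the representability issue:

import numpy as np

i32_max = 2**31 - 1
print(float(np.float32(i32_max)))             # 2147483648.0 -- rounded up to 2**31
print(float(np.float32(i32_max)) == 2.0**31)  # True: int32_max is not representable in fp32
# Hence the ROCm path compares against i32_max_f and returns int32_max explicitly,
# instead of relying on a cast that would be out of range.

int32_min (-2^31) is a power of two and converts exactly, but as the comment says, the kernel clamps it too for symmetry.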
@@ -47,6 +105,23 @@ __global__ void static_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void static_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, azp_type const* azp_ptr,
+    const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
+  azp_type const azp = *azp_ptr;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 template <typename scalar_t, typename scale_type>
 __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
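The new static_scaled_int8_azp_quant_kernel computes int32_to_int8(float_to_int32_rn(val / scale) + azp) per element. A hedged Python reference of the same math, mirroring what the tests below do (x, scale and azp are placeholders):

import torch

def static_azp_quant_ref(x: torch.Tensor, scale: float, azp: int) -> torch.Tensor:
    # Round to nearest (ties to even), add the zero point, then saturate to int8.
    q = torch.round(x.float() / scale) + azp
    return q.clamp(-128, 127).to(torch.int8)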
@@ -80,14 +155,68 @@ __global__ void dynamic_scaled_int8_quant_kernel(
   }
 }

+template <typename scalar_t, typename scale_type, typename azp_type>
+__global__ void dynamic_scaled_int8_azp_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, azp_type* azp, const int hidden_size) {
+  int const token_idx = blockIdx.x;
+
+  // Scan for the min and max value for this token
+  float max_val = std::numeric_limits<float>::min();
+  float min_val = std::numeric_limits<float>::max();
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
+    max_val = std::max(max_val, val);
+    min_val = std::min(min_val, val);
+  }
+
+  // Reduce the max and min values across the block
+  using BlockReduce = cub::BlockReduce<float, 1024>;
+  __shared__ typename BlockReduce::TempStorage reduceStorage;
+  max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x);
+  __syncthreads();  // Make sure min doesn't mess with max shared memory
+  min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x);
+
+  __shared__ scale_type scale_sh;
+  __shared__ azp_type azp_sh;
+
+  // Compute the scale and zero point and store them, only on the first thread
+  if (threadIdx.x == 0) {
+    float const scale_val = (max_val - min_val) / 255.0f;
+    // Use rounding to even (same as torch.round)
+    auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val);
+    auto const azp_val = static_cast<azp_type>(azp_float);
+
+    // Store the scale and azp into shared and global
+    scale[token_idx] = scale_sh = scale_val;
+    azp[token_idx] = azp_sh = azp_val;
+  }
+
+  // Wait for the scale and azp to be computed
+  __syncthreads();
+
+  float const scale_val = scale_sh;
+  azp_type const azp_val = azp_sh;
+
+  // Quantize the values
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const quant_val =
+        int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
+    out[token_idx * hidden_size + i] = quant_val;
+  }
+}
+
 }  // namespace vllm

 void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
                               torch::Tensor const& input,  // [..., hidden_size]
-                              torch::Tensor const& scale) {
+                              torch::Tensor const& scale,
+                              c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
+  TORCH_CHECK(!azp || azp->numel() == 1);

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
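For the dynamic case, each thread block handles one token: a strided loop plus a cub::BlockReduce find the token's min and max, then the first thread derives scale = (max - min) / 255 and azp = round(-128 - min / scale), so that min maps to -128 and max to 127. A per-token reference sketch in Python, illustrative and matching the math above and in the tests:

import torch

def dynamic_azp_quant_ref(x: torch.Tensor):
    x = x.float()
    x_max = x.max(dim=-1, keepdim=True).values
    x_min = x.min(dim=-1, keepdim=True).values
    scale = (x_max - x_min) / 255.0
    azp = torch.round(-128.0 - x_min / scale).to(torch.int32)
    q = (torch.round(x / scale) + azp).clamp(-128, 127).to(torch.int8)
    return q, scale, azp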
@@ -96,19 +225,29 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
-        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scale.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::static_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::static_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scale.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }

 void dynamic_scaled_int8_quant(
     torch::Tensor& out,          // [..., hidden_size]
     torch::Tensor const& input,  // [..., hidden_size]
-    torch::Tensor& scales) {
+    torch::Tensor& scales, c10::optional<torch::Tensor> const& azp) {
   TORCH_CHECK(input.is_contiguous());
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scales.is_contiguous());
+  TORCH_CHECK(!azp || azp->is_contiguous());

   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
@@ -117,9 +256,17 @@ void dynamic_scaled_int8_quant(
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
-        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
-                                         out.data_ptr<int8_t>(),
-                                         scales.data_ptr<float>(), hidden_size);
+        if (!azp) {
+          vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), hidden_size);
+        } else {
+          vllm::dynamic_scaled_int8_azp_quant_kernel<scalar_t, float, int32_t>
+              <<<grid, block, 0, stream>>>(
+                  input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
+                  scales.data_ptr<float>(), azp->data_ptr<int32_t>(),
+                  hidden_size);
+        }
       });
 }
@@ -336,14 +336,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale) -> "
-      "()");
+      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
-      "()");
+      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
 }
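In the schemas, "Tensor? azp" marks an optional read-only argument for the static op, while "Tensor!? azp" marks an optional argument that the dynamic op writes into. A usage sketch against the registered ops; the shapes and scale value are made up for illustration, and it assumes the compiled vLLM extension has been loaded (e.g. via from vllm import _custom_ops) on a CUDA build:

import torch

x = torch.rand(16, 256, dtype=torch.float16, device="cuda")
out = torch.empty_like(x, dtype=torch.int8)

# Static, symmetric: pass None for azp ("Tensor? azp").
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
torch.ops._C.static_scaled_int8_quant(out, x, scale, None)

# Dynamic, asymmetric: the op fills both scales and azp ("Tensor! scale, Tensor!? azp").
scales = torch.empty(16, 1, dtype=torch.float32, device="cuda")
azp = torch.empty(16, 1, dtype=torch.int32, device="cuda")
torch.ops._C.dynamic_scaled_int8_quant(out, x, scales, azp)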
@@ -13,14 +13,28 @@ SEEDS = [0]
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]


-def opcheck_int8_quant(output, input, scale=None):
-    if scale is not None:
-        opcheck(torch.ops._C.static_scaled_int8_quant, (output, input, scale))
+def opcheck_int8_quant_static(output, input, scale, azp=None):
+    if azp is None:
+        opcheck(torch.ops._C.static_scaled_int8_quant,
+                (output, input, scale, None))
     else:
-        scale = torch.empty((input.numel() // input.shape[-1], 1),
-                            device=input.device,
-                            dtype=torch.float32)
-        opcheck(torch.ops._C.dynamic_scaled_int8_quant, (output, input, scale))
+        opcheck(torch.ops._C.static_scaled_int8_quant,
+                (output, input, scale, azp))
+
+
+def opcheck_int8_quant_dynamic(output, input, symmetric=True):
+    scale = torch.empty((input.numel() // input.shape[-1], 1),
+                        device=input.device,
+                        dtype=torch.float32)
+    if symmetric:
+        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
+                (output, input, scale, None))
+    else:
+        azp = torch.empty((input.numel() // input.shape[-1], 1),
+                          device=input.device,
+                          dtype=torch.int32)
+        opcheck(torch.ops._C.dynamic_scaled_int8_quant,
+                (output, input, scale, azp))


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -38,14 +52,56 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
     # reference
     ref_out, ref_scales = ref_dynamic_per_token_quant(x, torch.int8)
     # kernel
-    ops_out, ops_scales = scaled_int8_quant(x)
+    ops_out, ops_scales, _ = scaled_int8_quant(x)

     torch.testing.assert_close(ops_scales, ref_scales)
-    torch.testing.assert_close(
-        ops_out, ref_out, atol=1,
-        rtol=0.0)  # big atol to account for rounding errors
+    # big atol to account for rounding errors
+    torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)

-    opcheck_int8_quant(ops_out, x)
+    opcheck_int8_quant_dynamic(ops_out, x)


+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@torch.inference_mode()
+def test_dynamic_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
+                                       dtype: torch.dtype, seed: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
+                   device="cuda") * 1000 - 300
+
+    x_token_max, _ = x.to(dtype=torch.float32).max(dim=1, keepdim=True)
+    x_token_min, _ = x.to(dtype=torch.float32).min(dim=1, keepdim=True)
+
+    # calculate scale and azp, and adjust the range
+    scales = (x_token_max - x_token_min) / torch.tensor(255.0)
+    azps = torch.round(torch.tensor(-128.0) - x_token_min / scales).to(
+        torch.int32)
+
+    torch_out = ((x / scales).round() + azps).clamp(
+        int8_traits.min, int8_traits.max).to(torch.int8)
+    assert torch_out.min() >= int8_traits.min and torch_out.max(
+    ) <= int8_traits.max
+
+    ops_out = torch.empty_like(x, dtype=torch.int8)
+    scales_out = torch.empty_like(scales, dtype=torch.float32)
+    azp_out = torch.empty_like(azps, dtype=torch.int32)
+    torch.ops._C.dynamic_scaled_int8_quant(ops_out, x, scales_out, azp_out)
+
+    if (not torch.allclose(scales_out, scales)):
+        print(torch.argmax(torch.abs(scales_out - scales)))
+    torch.testing.assert_close(scales_out, scales)
+    # big atol to account for rounding errors
+    torch.testing.assert_close(azp_out, azps, atol=1, rtol=0.0)
+    # if AZP is off by 1, after rounding-to-even, the output may be off by 2
+    torch.testing.assert_close(ops_out, torch_out, atol=2, rtol=0.0)
+
+    opcheck_int8_quant_dynamic(ops_out, x, False)
+
+
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -62,14 +118,76 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
     int8_traits = torch.iinfo(torch.int8)

     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
-    scale = torch.tensor([scale], dtype=torch.float32, device="cuda")
+    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")

-    out1 = (x / scale).round().clamp(int8_traits.min,
-                                     int8_traits.max).to(torch.int8)
-    out2, _ = scaled_int8_quant(x, scale)
+    out1 = (x / scale_arg).round().clamp(int8_traits.min,
+                                         int8_traits.max).to(torch.int8)
+    out2, _, _ = scaled_int8_quant(x, scale_arg)

-    torch.testing.assert_close(
-        out1, out2, atol=1,
-        rtol=0.0)  # big atol to account for rounding errors
+    # big atol to account for rounding errors
+    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)

-    opcheck_int8_quant(out2, x, scale)
+    opcheck_int8_quant_static(out2, x, scale_arg)
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("scale", SCALE[2:])  # Reduce test time
+@pytest.mark.parametrize("azp", [-255, 54])
+@torch.inference_mode()
+def test_static_scaled_int8_azp_quant(num_tokens: int, hidden_size: int,
+                                      dtype: torch.dtype, seed: int,
+                                      scale: float, azp: int) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype,
+                   device="cuda") * 1000 - 300
+
+    out1 = ((x / scale).round() + azp).clamp(int8_traits.min,
+                                             int8_traits.max).to(torch.int8)
+    out2 = torch.empty_like(x, dtype=torch.int8)
+    scale_arg = torch.tensor([scale], dtype=torch.float32, device="cuda")
+    azp_arg = torch.tensor([azp], dtype=torch.int32, device="cuda")
+
+    torch.ops._C.static_scaled_int8_quant(out2, x, scale_arg, azp_arg)
+
+    # big atol to account for rounding errors
+    torch.testing.assert_close(out1, out2, atol=1, rtol=0.0)
+
+    opcheck_int8_quant_static(out2, x, scale_arg, azp_arg)
+
+
+@pytest.mark.parametrize("is_max", [True, False])
+@torch.inference_mode()
+def test_static_scaled_int8_azp_quant_saturating_cast(is_max: bool) -> None:
+    # Test that the saturating cast works correctly for values near i32 max/min
+
+    from numpy import inf, nextafter
+
+    int32_traits = torch.iinfo(torch.int32)
+    val = float(int32_traits.max if is_max else int32_traits.min)
+
+    x_vals = [[
+        nextafter(val, inf), val + 1, val, val - 1,
+        nextafter(val, -inf)
+    ]]
+    x = torch.tensor(x_vals, dtype=torch.float32, device="cuda")
+
+    # The calculation in the kernel is: cast<int8>(cast<int32>(x / scale) + azp)
+    # where cast<T> is a saturating cast to type T.
+    # Scale is set to 1.0 so that the input values are the ones that are cast.
+    # AZP is set to 0 to make sure the int8 saturating cast is tested as well.
+    scale = torch.scalar_tensor(1.0, dtype=torch.float32, device="cuda")
+    azp = torch.scalar_tensor(0, dtype=torch.int32, device="cuda")
+
+    int8_traits = torch.iinfo(torch.int8)
+    val_i8 = int8_traits.max if is_max else int8_traits.min
+    expected = torch.full((1, 5), val_i8, dtype=torch.int8, device="cuda")
+
+    out = torch.empty_like(expected)
+    torch.ops._C.static_scaled_int8_quant(out, x, scale, azp)
+    torch.testing.assert_close(expected, out, atol=0, rtol=0)
@@ -684,32 +684,43 @@ def scaled_fp8_quant(

 # int8
 def scaled_int8_quant(
-        input: torch.Tensor,
-        scale: Optional[torch.Tensor] = None
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    azp: Optional[torch.Tensor] = None,
+    symmetric: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """
-    Quantize the input tensor to int8 and return the quantized tensor and scale.
+    Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp.

     Args:
         input: The input tensor to be quantized to int8.
         scale: Optional scaling factor for the int8 quantization.
             When not provided, we invoke dynamic-per-token quantization.
+        azp: Optional zero-point for the int8 quantization.
+            Must be provided for asymmetric quantization if `scale` is provided.
+        symmetric: Whether to use symmetric quantization (scale only, azp ignored).

     Returns:
-      Tuple[Torch.Tensor, Torch.Tensor] : Output int8 tensor and scales.
+      Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp.
     """
     output = torch.empty_like(input, dtype=torch.int8)
     if scale is not None:
         # static-per-tensor quantization.
-        torch.ops._C.static_scaled_int8_quant(output, input, scale)
-        return output, scale
+        assert symmetric == (
+            azp is
+            None), "azp must only be provided for asymmetric quantization."
+        torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
+        return output, scale, None

     # dynamic-per-token quantization.
     input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                                device=input.device,
                                dtype=torch.float32)
-    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales)
-    return output, input_scales
+    input_azp = None if symmetric else torch.empty_like(input_scales,
+                                                        dtype=torch.int32)
+    torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales,
+                                           input_azp)
+    return output, input_scales, input_azp


 # qqq ops
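With the new signature, scaled_int8_quant always returns a 3-tuple, which is why the call sites below unpack a trailing _. A usage sketch with hypothetical tensors, assuming a CUDA build of vLLM:

import torch
from vllm import _custom_ops as ops

x = torch.rand(16, 256, dtype=torch.float16, device="cuda")

# Dynamic per-token, symmetric: scales are computed, the azp slot is None.
x_q, x_scale, _ = ops.scaled_int8_quant(x)

# Dynamic per-token, asymmetric: scales and zero points are both computed.
x_q, x_scale, x_azp = ops.scaled_int8_quant(x, symmetric=False)

# Static per-tensor, symmetric: the caller supplies the scale.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_q, x_scale, _ = ops.scaled_int8_quant(x, scale)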
@@ -260,7 +260,7 @@ class QQQLinearMethod(LinearMethodBase):
         size_k = x_2d.shape[1]
         size_n = s_ch.shape[1]

-        x_int8, s_tok = ops.scaled_int8_quant(x_2d)
+        x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)

         output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
                                         workspace, size_m, size_n, size_k)
@@ -188,7 +188,7 @@ def apply_int8_linear(
     # ops.scaled_int8_quant supports both dynamic and static quant.
     # * dynamic, layer.input_scale is None and x_scale computed from x.
     # * static, layer.input_scale is scalar and x_scale is input_scale.
-    x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
+    x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)

     return ops.cutlass_scaled_mm(x_q,
                                  weight,