2024-05-23 17:29:18 -04:00
|
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
|
|
#include <torch/extension.h>
|
|
|
|
#include <cmath>
|
|
|
|
|
|
|
|
#include "../../dispatch_utils.h"
|
2024-06-07 12:36:26 -04:00
|
|
|
#include "../../reduction_utils.cuh"
|
2024-05-23 17:29:18 -04:00
|
|
|
|
|
|
|
// Convert a float to int8 with round-to-nearest-integer and saturation to
// [-128, 127], matching the semantics of the PTX `cvt.rni.sat.s8.f32`
// instruction on both the CUDA and ROCm build paths.
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
  // ROCm has no direct equivalent of the PTX cvt instruction, so emulate
  // round + saturate with standard-library float math.
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round (nearbyint uses the current rounding mode; the device default is
  // round-to-nearest-even, matching `rni` on the CUDA path)
  float dst = std::nearbyint(x);
  // saturate to the representable int8 range before the narrowing cast
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  // cvt.rni.sat.s8.f32: round-to-nearest-int, saturating, f32 -> s8.
  // The result occupies the low byte of the 32-bit destination register.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  // Reinterpret the low byte of the register as the signed result.
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
|
|
|
|
|
|
|
|
namespace vllm {
|
|
|
|
|
|
|
|
// Quantize `input` to int8 using a single precomputed scale.
//
// Launch layout: one block per token (grid.x == num_tokens); threads stride
// over the `hidden_size` elements of that token's row.
//
// input:     [num_tokens, hidden_size], row-major, contiguous
// out:       [num_tokens, hidden_size] int8 output
// scale_ptr: pointer to a single scale value; each element is divided by it
//            before rounding/saturation.
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  // Use a 64-bit token index: `token_idx * hidden_size + i` would overflow
  // 32-bit int once num_tokens * hidden_size exceeds INT_MAX.
  int64_t const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
  }
}
|
2024-06-07 12:36:26 -04:00
|
|
|
|
|
|
|
// Quantize `input` to int8 with a per-token dynamic scale.
//
// Launch layout: one block per token (grid.x == num_tokens); threads stride
// over the `hidden_size` elements of that token's row.
//
// Pass 1 computes the row's absolute maximum (block-wide reduction via
// blockReduceMax); the scale absmax/127 is written to scale[token_idx].
// Pass 2 quantizes the row with the reciprocal scale.
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  // 64-bit token index: `token_idx * hidden_size + i` would overflow 32-bit
  // int once num_tokens * hidden_size exceeds INT_MAX.
  int64_t const token_idx = blockIdx.x;
  float absmax_val = 0.0f;
  float const zero = 0.0f;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[token_idx * hidden_size + i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  // All threads must see the reduced value before pass 2 reads it.
  __syncthreads();

  // Guard against an all-zero row: 127/0 would be +inf and 0*inf is NaN
  // (undefined behavior when narrowed to int8 on the ROCm path). A zero
  // scale quantizes the all-zero row to all zeros, which is correct.
  float const tmp_scale =
      block_absmax_val > 0.0f ? 127.0f / block_absmax_val : 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
  }
}
|
|
|
|
|
2024-05-23 17:29:18 -04:00
|
|
|
} // namespace vllm
|
|
|
|
|
2024-06-03 12:52:30 -04:00
|
|
|
// Host launcher for static (precomputed-scale) int8 quantization.
// Grid: one block per token; block: min(hidden_size, 1024) threads striding
// over the row. Runs on the current PyTorch CUDA stream.
void static_scaled_int8_quant(torch::Tensor& out,    // [..., hidden_size]
                              torch::Tensor const& input,  // [..., hidden_size]
                              torch::Tensor const& scale) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);

  int const hidden = input.size(-1);
  int const n_tokens = input.numel() / hidden;

  // Launch configuration: one block per token, capped at 1024 threads/block.
  dim3 const grid_dim(n_tokens);
  dim3 const block_dim(std::min(hidden, 1024));
  cudaStream_t const cuda_stream = at::cuda::getCurrentCUDAStream();

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
        vllm::static_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid_dim, block_dim, 0, cuda_stream>>>(
                input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
                scale.data_ptr<float>(), hidden);
      });
}
|
2024-06-07 12:36:26 -04:00
|
|
|
|
|
|
|
// Host launcher for dynamic (per-token-scale) int8 quantization.
// Grid: one block per token; block: min(hidden_size, 1024) threads striding
// over the row. The kernel writes one float scale per token into `scales`.
// Runs on the current PyTorch CUDA stream.
void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales) {     // [num_tokens] float, written by the kernel
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());
  // The kernel indexes scales with raw pointer arithmetic (scale[token_idx]),
  // so it must be contiguous as well; this check was missing.
  TORCH_CHECK(scales.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scales.data_ptr<float>(), hidden_size);
      });
}
|