[Kernel] Dynamic Per-Token Activation Quantization (#5037)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Parent: dc49fb892c
Commit: ca3ea51bde
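
For orientation, here is a minimal PyTorch reference of what dynamic per-token activation quantization computes (one absmax-derived scale per token, then symmetric int8 rounding). This is an illustrative sketch of the math only, not the CUDA kernel added below:

import torch

def dynamic_per_token_int8_quant_ref(x: torch.Tensor):
    # x: [num_tokens, hidden_size] activations in fp16/bf16/fp32.
    # One scale per token: row absmax / 127, so every row maps into [-127, 127].
    x_f32 = x.to(torch.float32)
    scales = x_f32.abs().max(dim=-1, keepdim=True).values / 127.0
    q = torch.round(x_f32 / scales).clamp(-128, 127).to(torch.int8)
    return q, scales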
@@ -97,6 +97,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor const& scale);
 
+void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor& scales);
+
 void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                      torch::Tensor lookup_table);
@@ -70,6 +70,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
           "Compute int8 quantized tensor for given scaling factor");
 
+  ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
+          "Compute int8 quantized tensor and scaling factor");
+
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
   cache_ops.def("swap_blocks", &swap_blocks,
@@ -3,6 +3,7 @@
 #include <cmath>
 
 #include "../../dispatch_utils.h"
+#include "../../reduction_utils.cuh"
 
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
@@ -27,17 +28,48 @@ namespace vllm {
 
 template <typename scalar_t, typename scale_type>
 __global__ void static_scaled_int8_quant_kernel(
-    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
-    const scale_type* scale_ptr, const int hidden_size) {
-  const int tid = threadIdx.x;
-  const int token_idx = blockIdx.x;
-  scale_type scale = *scale_ptr;
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
 
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * hidden_size + i] =
-        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
   }
 }
 
+template <typename scalar_t, typename scale_type>
+__global__ void dynamic_scaled_int8_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  float absmax_val = 0.0f;
+  float const zero = 0.0f;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    float val = static_cast<float>(input[token_idx * hidden_size + i]);
+    val = val > zero ? val : -val;
+    absmax_val = val > absmax_val ? val : absmax_val;
+  }
+
+  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  __shared__ float block_absmax_val;
+  if (tid == 0) {
+    block_absmax_val = block_absmax_val_maybe;
+    scale[token_idx] = block_absmax_val / 127.0f;
+  }
+  __syncthreads();
+
+  float const tmp_scale = 127.0f / block_absmax_val;
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
+  }
+}
+
 }  // namespace vllm
 
 void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
@@ -47,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
 
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  dim3 const grid(num_tokens);
+  dim3 const block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -60,3 +92,24 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
             scale.data_ptr<float>(), hidden_size);
       });
 }
+
+void dynamic_scaled_int8_quant(
+    torch::Tensor& out,          // [..., hidden_size]
+    torch::Tensor const& input,  // [..., hidden_size]
+    torch::Tensor& scales) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  dim3 const grid(num_tokens);
+  dim3 const block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
+        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
+                                         out.data_ptr<int8_t>(),
+                                         scales.data_ptr<float>(), hidden_size);
+      });
+}
@@ -21,29 +21,47 @@
 #include "cuda_compat.h"
 
 namespace vllm {
-template <typename T, int numLanes = WARP_SIZE>
-__inline__ __device__ T warpReduceSum(T val) {
-  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
-                "numLanes is not a positive power of 2!");
-  static_assert(numLanes <= WARP_SIZE);
-#pragma unroll
-  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
-    val += VLLM_SHFL_XOR_SYNC(val, mask);
-  return val;
+
+namespace detail {
+
+template <typename T>
+__inline__ __device__ T _max(T a, T b) {
+  return max(a, b);
 }
 
+template <typename T>
+__inline__ __device__ T _sum(T a, T b) {
+  return a + b;
+}
+
+}  // namespace detail
+
+template <typename T>
+using ReduceFnType = T (*)(T, T);
+
 // Helper function to return the next largest power of 2
 static constexpr int _nextPow2(unsigned int num) {
   if (num <= 1) return num;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
-/* Calculate the sum of all elements in a block */
+template <typename T, int numLanes = WARP_SIZE>
+__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
+  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
+                "numLanes is not a positive power of 2!");
+  static_assert(numLanes <= WARP_SIZE);
+#pragma unroll
+  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
+    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
+
+  return val;
+}
+
 template <typename T, int maxBlockSize = 1024>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
   static_assert(maxBlockSize <= 1024);
   if constexpr (maxBlockSize > WARP_SIZE) {
-    val = warpReduceSum<T>(val);
+    val = warpReduce<T>(val, fn);
     // Calculates max number of lanes that need to participate in the last
     // warpReduce
     constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {
 
     val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
                                                         : (T)(0.0f);
-    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
+    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
   } else {
     // A single warpReduce is equal to blockReduce
-    val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
+    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
   }
   return val;
 }
 
+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceMax(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
+}
+
+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceSum(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
+}
+
 }  // namespace vllm
@@ -4,7 +4,8 @@ import torch
 from vllm._C import ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
+                8193]  # Arbitrary values for testing
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
 SEEDS = [0]
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
@@ -14,17 +15,48 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("scale", SCALE)
 @torch.inference_mode()
-def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
-               seed: int, scale: float) -> None:
+def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                   dtype: torch.dtype, seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
 
-    out1 = (x / scale).round().clamp(
-        torch.iinfo(torch.int8).min,
-        torch.iinfo(torch.int8).max).to(torch.int8)
+    x_token_max, _ = x.max(dim=1)
+    x_token_max = x_token_max.to(dtype=torch.float32)
+    scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
+                                                      dtype=torch.float32)
+    torch_out = (x / scales).round().clamp(int8_traits.min,
+                                           int8_traits.max).to(torch.int8)
+
+    ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
+    scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
+    ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)
+
+    assert torch.allclose(scales_out, scales)
+    assert torch.allclose(torch_out, ops_out,
+                          atol=1)  # big atol to account for rounding errors
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("scale", SCALE)
+@torch.inference_mode()
+def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                  dtype: torch.dtype, seed: int,
+                                  scale: float) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
+
+    out1 = (x / scale).round().clamp(int8_traits.min,
+                                     int8_traits.max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
     scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
@@ -6,7 +6,8 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
 import torch
 
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -34,3 +35,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
     assert qkv_proj.weight_scale.shard_splitter is not None
     assert qkv_proj.weight_scale.logical_widths is not None
     assert qkv_proj.input_scale.dtype is torch.float32
+
+
+def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
+    model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
+    llm = vllm_runner(model_path,
+                      quantization="sparseml",
+                      enforce_eager=True,
+                      dtype=torch.float16)
+    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+    layer = model.model.layers[0]
+
+    qkv_proj = layer.self_attn.qkv_proj
+
+    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+    assert qkv_proj.weight.dtype is torch.int8
@@ -266,21 +266,33 @@ def scaled_fp8_quant(
 
 
 # int8
-def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: torch.Tensor) -> torch.Tensor:
+def scaled_int8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Quantize the input tensor to int8 and return the quantized tensor.
+    Quantize the input tensor to int8 and return the quantized tensor and scale.
 
     Args:
         input: The input tensor to be quantized to int8.
-        scale: Scaling factor for the int8 quantization.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
 
     Returns:
-        torch.Tensor: Output tensor in int8.
+        Tuple[torch.Tensor, torch.Tensor] : Output int8 tensor and scales.
     """
-    q = torch.empty_like(input, dtype=torch.int8)
-    vllm_ops.static_scaled_int8_quant(q, input, scale)
-    return q
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        vllm_ops.static_scaled_int8_quant(output, input, scale)
+        return output, scale
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
+    return output, input_scales
 
 
 # moe
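
With the wrapper above, both quantization modes go through a single entry point. A short usage sketch follows; the tensor shapes and the example scale value are illustrative:

import torch
from vllm import _custom_ops as custom_ops

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token: no scale is passed in; one scale per token is returned.
x_q, token_scales = custom_ops.scaled_int8_quant(x)

# Static per-tensor: a precomputed activation scale is passed in and returned as-is.
act_scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_q_static, _ = custom_ops.scaled_int8_quant(x, act_scale)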
@@ -1,12 +1,16 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+from pydantic import BaseModel
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -47,10 +51,12 @@ class CompressedTensorsConfig(QuantizationConfig):
         targets = quant_config.get("targets")
         for target in targets:
             layer_quant_details[target] = {}
-            layer_quant_details[target]["weight"] = quant_config.get(
-                "weights")
-            layer_quant_details[target]["input"] = quant_config.get(
-                "input_activations")
+            layer_quant_details[target][
+                "weight"] = QuantizationArgs.parse_obj(
+                    quant_config.get("weights"))
+            layer_quant_details[target][
+                "input"] = QuantizationArgs.parse_obj(
+                    quant_config.get("input_activations"))
 
         return cls(layer_quant_details=layer_quant_details, ignore=ignore)
@@ -58,40 +64,46 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_config_filenames(cls) -> List[str]:
         return []
 
-    def _get_schema(self, weight_quant: Dict, input_quant: Dict):
-        # TODO: Refactor as additional cases are supported
-
-        weight_bit = weight_quant.get("num_bits")
-        input_bit = input_quant.get("num_bits")
-
-        weight_strategy = weight_quant.get("strategy")
-        input_strategy = input_quant.get("strategy")
-
-        weight_symmetric = weight_quant.get("symmetric")
-        input_symmetric = input_quant.get("symmetric")
-
-        is_8_bits = weight_bit == input_bit == 8
-        is_tensor = weight_strategy == input_strategy == "tensor"
-        is_symmetric = weight_symmetric and input_symmetric
-
-        if is_8_bits and is_tensor and is_symmetric and \
-                torch.cuda.is_available():
-            # CompressedTensorsW8A8StaticTensor only supports CUDA path for
-            # now.
+    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
+                               input_quant: BaseModel) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        is_tensor = (weight_quant.strategy == input_quant.strategy ==
+                     QuantizationStrategy.TENSOR.value)
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_static = not weight_quant.dynamic and not input_quant.dynamic
+
+        return is_8_bits and is_tensor and is_symmetric and is_static
+
+    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
+                               input_quant: BaseModel) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        is_token_tensor = (weight_quant.strategy
+                           == QuantizationStrategy.TENSOR.value) and (
+                               input_quant.strategy
+                               == QuantizationStrategy.TOKEN.value)
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
+
+        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
+
+    def _get_schema(self, weight_quant: BaseModel,
+                    input_quant: BaseModel) -> "CompressedTensorsScheme":
+        if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8StaticTensor()
-        raise NotImplementedError(
-            "Scheme not supported. Only CUDA, 8-bit static symmtetric "
-            "per tensor quantization is currently supported")
+
+        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8DynamicToken()
+
+        raise NotImplementedError("Scheme not supported.")
 
     def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
 
-        # TODO: update with matching function from `compressed_tensors`
-        layer_type_name = None
-        layer_name_class = type(layer).__name__.lower()
-        for target in self.layer_quant_details:
-            if target.lower() in layer_name_class:
-                layer_type_name = target
-                break
+        layer_type_name = find_first_name_or_class_match(
+            name="",
+            module=layer,
+            targets=self.layer_quant_details.keys(),
+            check_contains=True)
+
         if layer_type_name is None:
             raise ValueError(f"Could not matching target for layer {layer}")
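
For reference, a config fragment that would satisfy the _is_dynamic_token_w8a8 check above, written as Python dicts in the shape consumed by QuantizationArgs.parse_obj. The field names follow the QuantizationArgs model added in this commit; the exact checkpoint config layout is an assumption for illustration:

# Hypothetical per-target quantization entries for the dynamic-per-token scheme.
dynamic_token_w8a8_config = {
    "weights": {"num_bits": 8, "strategy": "tensor", "symmetric": True,
                "dynamic": False},
    "input_activations": {"num_bits": 8, "strategy": "token", "symmetric": True,
                          "dynamic": True},
}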
@@ -117,7 +129,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
                        **extra_weight_attrs):
         """
         Use the CompressedTensorsScheme associated with each layer to create
-        the necessary parameters for the layer.
+        the necessary parameters for the layer. See LinearMethodBase for param
+        details
+
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
@@ -139,7 +153,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         """
         Use the output of create_weights and the CompressedTensorsScheme
         associated with the layer to apply the forward pass with the
-        layer input.
+        layer input. See LinearMethodBase for param details
+
         """
 
         if bias is not None:
@@ -1,5 +1,7 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
+from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
+    CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
     CompressedTensorsW8A8StaticTensor)
@@ -0,0 +1,85 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as custom_ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW8A8DynamicToken"]
+
+
+class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        assert isinstance(shard_id, str)
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shard_id = self._shard_id_as_int(shard_id)
+        offset = sum(logical_widths[:shard_id])
+        size = logical_widths[shard_id]
+        # update loaded weight with copies for broadcast.
+        loaded_weight = loaded_weight.repeat(size)
+        return param[offset:offset + size], loaded_weight
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        # When the scales have a single value, it is required that they be
+        # on the CPU for performance and CUDA Graphs compatibility. Please
+        # refer to the comment in
+        # CompressedTensorsW8A8StaticTensor::create_weights for further
+        # information.
+        is_tensor_partitioned = len(output_partition_sizes) != 1
+        weight_scale_dim = sum(
+            output_partition_sizes) if is_tensor_partitioned else 1
+
+        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
+                                      requires_grad=False)
+
+        weight_scale = Parameter(torch.empty(weight_scale_dim,
+                                             dtype=torch.float32),
+                                 requires_grad=False)
+
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        set_weight_attrs(weight, {"weight_loader": weight_loader})
+        set_weight_attrs(weight, {"logical_widths": output_partition_sizes})
+
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        set_weight_attrs(
+            weight_scale, {
+                "shard_splitter": self.scales_shard_splitter,
+                "logical_widths": output_partition_sizes
+            })
+
+        layer.register_parameter("weight_zero_point", weight_zero_point)
+        set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+
+        x_q, input_scales = custom_ops.scaled_int8_quant(x)
+        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
+                                               weight_scale, x.dtype)
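
Conceptually, apply_weights above does dynamic activation quantization followed by an int8 GEMM with fused dequantization. A plain-PyTorch sketch of the same arithmetic, ignoring the CUTLASS kernel and performance, and assuming the usual per-token/per-channel scale broadcasting:

import torch

def w8a8_dynamic_token_ref(x, w_q, weight_scale):
    # x: [tokens, in] fp16/bf16; w_q: [out, in] int8; weight_scale: scalar or [out] fp32.
    x_f32 = x.to(torch.float32)
    token_scales = x_f32.abs().max(dim=-1, keepdim=True).values / 127.0
    x_q = torch.round(x_f32 / token_scales).clamp(-128, 127).to(torch.int8)

    acc = x_q.to(torch.float32) @ w_q.to(torch.float32).t()  # int8 GEMM accumulation
    return (acc * token_scales * weight_scale).to(x.dtype)   # dequantize per token / per channel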
@@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
         act_scale = layer.input_scale
 
         # Input quantize
-        x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
+        x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
 
         return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
                                                weight_scale, x.dtype)
@@ -0,0 +1,114 @@
+import re
+from enum import Enum
+from typing import Any, Dict, Iterable, Optional
+
+from pydantic import BaseModel, Field
+from torch.nn import Module
+
+
+class QuantizationType(str, Enum):
+    """
+    Enum storing quantization type options
+    """
+
+    INT = "int"
+    FLOAT = "float"
+
+
+class QuantizationStrategy(str, Enum):
+    """
+    Enum storing quantization strategy options
+    """
+
+    TENSOR = "tensor"
+    CHANNEL = "channel"
+    GROUP = "group"
+    BLOCK = "block"
+    TOKEN = "token"
+
+
+class QuantizationArgs(BaseModel):
+    """
+    User facing arguments used to define a quantization config
+    for weights or activations
+
+    :param num_bits: quantization bit depth
+    :param type: dtype to quantized to, either int or float
+    :param symmetric: whether or not quantization scale is symmetric
+    :param strategy: string determining the scope of scale/zero-point to apply
+    :param group_size: group length to use for the group strategy
+    :param block_structure: 2d block structure to use for the block
+        strategy, must be of the format "2x4", "8x16", etc.
+    :param dynamic: set True to perform dynamic quantization -
+        values will not be calibrated during calibration phase,
+        instead during inference new quantization ranges will be
+        observed with every sample. Defaults to False for static
+        quantization. Note that enabling dynamic quantization
+        will change the default observer to a memoryless one
+    """
+
+    num_bits: int = 8
+    type: QuantizationType = QuantizationType.INT
+    symmetric: bool = True
+    group_size: Optional[int] = None
+    strategy: Optional[QuantizationStrategy] = None
+    block_structure: Optional[str] = None
+    dynamic: bool = False
+    observer: str = Field(
+        default="minmax",
+        description=("The class to use to compute the quantization param - "
+                     "scale and zero-point'"),
+    )
+    observer_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=
+        ("optional dict of kwargs to be passed directly to torch quantization "
+         "Observers constructor excluding quantization range or symmetry"),
+    )
+
+
+def find_first_name_or_class_match(
+        name: str,
+        module: Module,
+        targets: Iterable[str],
+        check_contains: bool = False) -> Optional[str]:
+    """
+    Helper function to map the quantization details listed in the config
+    for a given list of targets against each model layer. First uses the
+    layer name to try and find a match. If no name match is found, uses
+    the layer class name. Returns None otherwise.
+
+    :param name: layer name
+    :param module: torch.nn.Module
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    return _find_first_match(name, targets) or _find_first_match(
+        module.__class__.__name__, targets, check_contains)
+
+
+def _find_first_match(value: str,
+                      targets: Iterable[str],
+                      check_contains: bool = False) -> Optional[str]:
+    """
+    Returns first element of target that matches value either
+    exactly or as a regex after 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+
+    :param value: string to compare the list of targets against
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    for target in targets:
+        if target.startswith("re:"):
+            pattern = target[3:]
+            if re.match(pattern, value):
+                return target
+        elif check_contains:
+            if target.lower() in value.lower():
+                return target
+        elif target == value:
+            return target
+    return None