[Kernel] Dynamic Per-Token Activation Quantization (#5037)

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Dipika Sikka 2024-06-07 12:36:26 -04:00 committed by GitHub
parent dc49fb892c
commit ca3ea51bde
12 changed files with 439 additions and 75 deletions


@@ -97,6 +97,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor const& scale);
void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor& scales);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);


@@ -70,6 +70,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
"Compute int8 quantized tensor and scaling factor");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def("swap_blocks", &swap_blocks,


@@ -3,6 +3,7 @@
#include <cmath>
#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
@@ -27,17 +28,48 @@ namespace vllm {
template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
const scalar_t* __restrict__ input, int8_t* __restrict__ out,
const scale_type* scale_ptr, const int hidden_size) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;
scale_type scale = *scale_ptr;
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type const* scale_ptr, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
scale_type const scale = *scale_ptr;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) / scale);
}
}
template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
scalar_t const* __restrict__ input, int8_t* __restrict__ out,
scale_type* scale, const int hidden_size) {
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
float absmax_val = 0.0f;
float const zero = 0.0f;
for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = static_cast<float>(input[token_idx * hidden_size + i]);
val = val > zero ? val : -val;
absmax_val = val > absmax_val ? val : absmax_val;
}
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float block_absmax_val;
if (tid == 0) {
block_absmax_val = block_absmax_val_maybe;
scale[token_idx] = block_absmax_val / 127.0f;
}
__syncthreads();
float const tmp_scale = 127.0f / block_absmax_val;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] = float_to_int8_rn(
static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
}
}
} // namespace vllm
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
@@ -47,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -60,3 +92,24 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
scale.data_ptr<float>(), hidden_size);
});
}
void dynamic_scaled_int8_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor& scales) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scales.data_ptr<float>(), hidden_size);
});
}
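
For readers who want the semantics without the CUDA details, here is a minimal PyTorch sketch of what the dynamic path computes: one block per token finds the row's absolute maximum, derives scale = absmax / 127, and quantizes with round-to-nearest into int8. The helper name below is illustrative, not part of this commit.

import torch

def reference_dynamic_int8_quant(x: torch.Tensor):
    # x: [num_tokens, hidden_size]; mirrors dynamic_scaled_int8_quant_kernel.
    x_f32 = x.float()
    # Per-token scale, equivalent to blockReduceMax(absmax_val) / 127.0f.
    scales = x_f32.abs().amax(dim=-1, keepdim=True) / 127.0
    # Round-to-nearest and saturate to the int8 range, equivalent to
    # float_to_int8_rn(val * (127.0f / absmax)).
    q = (x_f32 / scales).round().clamp(-128, 127).to(torch.int8)
    return q, scales

The actual kernel launches one block per token (grid = num_tokens) with up to 1024 threads striding over hidden_size.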


@@ -21,29 +21,47 @@
#include "cuda_compat.h"
namespace vllm {
template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduceSum(T val) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val += VLLM_SHFL_XOR_SYNC(val, mask);
return val;
namespace detail {
template <typename T>
__inline__ __device__ T _max(T a, T b) {
return max(a, b);
}
template <typename T>
__inline__ __device__ T _sum(T a, T b) {
return a + b;
}
} // namespace detail
template <typename T>
using ReduceFnType = T (*)(T, T);
// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
/* Calculate the sum of all elements in a block */
template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
return val;
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduceSum<T>(val);
val = warpReduce<T>(val, fn);
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
} else {
// A single warpReduce is equal to blockReduce
val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
}
return val;
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}
} // namespace vllm
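
The XOR-shuffle loop in warpReduce is a butterfly reduction: after log2(numLanes) rounds every lane holds the reduction of all lanes, which is why blockReduce can reuse the same primitive for both max and sum. A small Python emulation of the access pattern (illustrative only, not part of the commit):

def emulate_warp_reduce(vals, fn):
    # vals holds one value per lane; the lane count must be a positive power
    # of two no larger than the warp size, matching the static_asserts above.
    lanes = len(vals)
    assert lanes > 0 and (lanes & (lanes - 1)) == 0
    mask = lanes >> 1
    while mask > 0:
        # VLLM_SHFL_XOR_SYNC(val, mask) lets lane i read lane (i ^ mask).
        vals = [fn(vals[i], vals[i ^ mask]) for i in range(lanes)]
        mask >>= 1
    return vals  # every lane now holds the full reduction

# emulate_warp_reduce(list(range(8)), max) -> [7, 7, 7, 7, 7, 7, 7, 7]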


@@ -4,7 +4,8 @@ import torch
from vllm._C import ops
DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192] # Arbitrary values for testing
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
8193] # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096] # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
@@ -14,17 +15,48 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
seed: int, scale: float) -> None:
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
out1 = (x / scale).round().clamp(
torch.iinfo(torch.int8).min,
torch.iinfo(torch.int8).max).to(torch.int8)
x_token_max, _ = x.max(dim=1)
x_token_max = x_token_max.to(dtype=torch.float32)
scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
dtype=torch.float32)
torch_out = (x / scales).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)
assert torch.allclose(scales_out, scales)
assert torch.allclose(torch_out, ops_out,
atol=1) # big atol to account for rounding errors
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
dtype: torch.dtype, seed: int,
scale: float) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
int8_traits = torch.iinfo(torch.int8)
x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
out1 = (x / scale).round().clamp(int8_traits.min,
int8_traits.max).to(torch.int8)
out2 = torch.empty_like(x, dtype=torch.int8)
scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")


@@ -6,7 +6,8 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)
CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
CompressedTensorsW8A8StaticTensor)
def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -34,3 +35,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
assert qkv_proj.weight_scale.shard_splitter is not None
assert qkv_proj.weight_scale.logical_widths is not None
assert qkv_proj.input_scale.dtype is torch.float32
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
llm = vllm_runner(model_path,
quantization="sparseml",
enforce_eager=True,
dtype=torch.float16)
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
assert qkv_proj.weight.dtype is torch.int8


@@ -266,21 +266,33 @@ def scaled_fp8_quant(
# int8
def static_scaled_int8_quant(input: torch.Tensor,
scale: torch.Tensor) -> torch.Tensor:
def scaled_int8_quant(
input: torch.Tensor,
scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize the input tensor to int8 and return the quantized tensor.
Quantize the input tensor to int8 and return the quantized tensor and scale.
Args:
input: The input tensor to be quantized to int8.
scale: Scaling factor for the int8 quantization.
scale: Optional scaling factor for the int8 quantization.
When not provided, we invoke dynamic-per-token quantization.
Returns:
torch.Tensor: Output tensor in int8.
Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
"""
q = torch.empty_like(input, dtype=torch.int8)
vllm_ops.static_scaled_int8_quant(q, input, scale)
return q
output = torch.empty_like(input, dtype=torch.int8)
if scale is not None:
# static-per-tensor quantization.
vllm_ops.static_scaled_int8_quant(output, input, scale)
return output, scale
# dynamic-per-token quantization.
input_scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
return output, input_scales
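
A short usage sketch of the new unified wrapper (tensor shapes and values are illustrative): passing a scale keeps the old static per-tensor behaviour, while omitting it triggers the new dynamic per-token path and returns one float32 scale per token.

import torch
from vllm import _custom_ops as custom_ops

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token: no scale supplied, scales of shape [num_tokens, 1] come back.
x_q, x_scales = custom_ops.scaled_int8_quant(x)

# Static per-tensor: the caller-provided scale is passed through unchanged.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_q_static, _ = custom_ops.scaled_int8_quant(x, scale)
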
# moe


@@ -1,12 +1,16 @@
from typing import Any, Dict, List, Optional
import torch
from pydantic import BaseModel
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor)
CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
CompressedTensorsW8A8StaticTensor)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
class CompressedTensorsConfig(QuantizationConfig):
@@ -47,10 +51,12 @@ class CompressedTensorsConfig(QuantizationConfig):
targets = quant_config.get("targets")
for target in targets:
layer_quant_details[target] = {}
layer_quant_details[target]["weight"] = quant_config.get(
"weights")
layer_quant_details[target]["input"] = quant_config.get(
"input_activations")
layer_quant_details[target][
"weight"] = QuantizationArgs.parse_obj(
quant_config.get("weights"))
layer_quant_details[target][
"input"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
return cls(layer_quant_details=layer_quant_details, ignore=ignore)
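
For orientation, the quant_config block being parsed here looks roughly like the sketch below; the field names follow the QuantizationArgs model added later in this commit, while the target name and values are illustrative.

quant_config = {
    "targets": ["Linear"],
    "weights": {"num_bits": 8, "type": "int", "symmetric": True,
                "strategy": "tensor", "dynamic": False},
    "input_activations": {"num_bits": 8, "type": "int", "symmetric": True,
                          "strategy": "token", "dynamic": True},
}
# QuantizationArgs.parse_obj(...) validates each dict into a pydantic model whose
# num_bits / strategy / symmetric / dynamic fields drive the scheme checks below.
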
@@ -58,40 +64,46 @@
def get_config_filenames(cls) -> List[str]:
return []
def _get_schema(self, weight_quant: Dict, input_quant: Dict):
# TODO: Refactor as additional cases are supported
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
is_tensor = (weight_quant.strategy == input_quant.strategy ==
QuantizationStrategy.TENSOR.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_static = not weight_quant.dynamic and not input_quant.dynamic
weight_bit = weight_quant.get("num_bits")
input_bit = input_quant.get("num_bits")
return is_8_bits and is_tensor and is_symmetric and is_static
weight_strategy = weight_quant.get("strategy")
input_strategy = input_quant.get("strategy")
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
is_token_tensor = (weight_quant.strategy
== QuantizationStrategy.TENSOR.value) and (
input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_symmetric = weight_quant.symmetric and input_quant.symmetric
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
weight_symmetric = weight_quant.get("symmetric")
input_symmetric = input_quant.get("symmetric")
return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
is_8_bits = weight_bit == input_bit == 8
is_tensor = weight_strategy == input_strategy == "tensor"
is_symmetric = weight_symmetric and input_symmetric
if is_8_bits and is_tensor and is_symmetric and \
torch.cuda.is_available():
# CompressedTensorsW8A8StaticTensor only supports CUDA path for
# now.
def _get_schema(self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor()
raise NotImplementedError(
"Scheme not supported. Only CUDA, 8-bit static symmtetric "
"per tensor quantization is currently supported")
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken()
raise NotImplementedError("Scheme not supported.")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
# TODO: update with matching function from `compressed_tensors`
layer_type_name = None
layer_name_class = type(layer).__name__.lower()
for target in self.layer_quant_details:
if target.lower() in layer_name_class:
layer_type_name = target
break
layer_type_name = find_first_name_or_class_match(
name="",
module=layer,
targets=self.layer_quant_details.keys(),
check_contains=True)
if layer_type_name is None:
raise ValueError(f"Could not matching target for layer {layer}")
@@ -117,7 +129,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer.
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
@@ -139,7 +153,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input.
layer input. See LinearMethodBase for param details
"""
if bias is not None:


@@ -1,5 +1,7 @@
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501
CompressedTensorsW8A8DynamicToken)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
CompressedTensorsW8A8StaticTensor)


@@ -0,0 +1,85 @@
from typing import Callable, List, Tuple, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW8A8DynamicToken"]
class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]
def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
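
A quick sketch of what this splitter does when a fused QKV layer is loaded (the widths and scale value are made up): the shard's slice of the fused scale parameter is selected by offset/size, and the shard's single per-tensor scale is repeated so the subsequent copy broadcasts across that slice.

import torch

logical_widths = [1024, 256, 256]          # q, k, v output widths (illustrative)
param = torch.empty(sum(logical_widths))   # fused weight_scale buffer
loaded_weight = torch.tensor([0.05])       # one scale loaded for the "k" shard

offset, size = sum(logical_widths[:1]), logical_widths[1]
dest = param[offset:offset + size]         # param[1024:1280], as returned above
dest.copy_(loaded_weight.repeat(size))     # broadcast the scalar over the k slice
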
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
# When the scales have a single value, it is required that they be
# on the CPU for performance and CUDA Graphs compatibility. Please
# refer to the comment in
# CompressedTensorsW8A8StaticTensor::create_weights for further
# information.
is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
requires_grad=False)
weight_scale = Parameter(torch.empty(weight_scale_dim,
dtype=torch.float32),
requires_grad=False)
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
set_weight_attrs(weight, {"weight_loader": weight_loader})
set_weight_attrs(weight, {"logical_widths": output_partition_sizes})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
set_weight_attrs(
weight_scale, {
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes
})
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
weight_scale = layer.weight_scale
x_q, input_scales = custom_ops.scaled_int8_quant(x)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
weight_scale, x.dtype)
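
In effect the forward pass quantizes activations per token on the fly, runs an int8 GEMM, and dequantizes with both scale sets in one fused call. A rough torch reference of the math (not the actual fused CUTLASS kernel; the helper name is illustrative):

import torch

def reference_w8a8_dynamic_token_matmul(x, w_q, w_scale):
    # x: [tokens, in] fp16/bf16, w_q: [out, in] int8, w_scale: [out] fp32.
    # Per-token dynamic activation quantization, same math as scaled_int8_quant.
    x_scales = x.float().abs().amax(dim=-1, keepdim=True) / 127.0
    x_q = (x.float() / x_scales).round().clamp(-128, 127)
    # int8 GEMM emulated in fp32, then dequantized with both scales; this is
    # what cutlass_scaled_mm_dq fuses into a single kernel.
    out = (x_q @ w_q.float().t()) * x_scales * w_scale.float()
    return out.to(x.dtype)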


@@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
act_scale = layer.input_scale
# Input quantize
x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
weight_scale, x.dtype)


@@ -0,0 +1,114 @@
import re
from enum import Enum
from typing import Any, Dict, Iterable, Optional
from pydantic import BaseModel, Field
from torch.nn import Module
class QuantizationType(str, Enum):
"""
Enum storing quantization type options
"""
INT = "int"
FLOAT = "float"
class QuantizationStrategy(str, Enum):
"""
Enum storing quantization strategy options
"""
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
TOKEN = "token"
class QuantizationArgs(BaseModel):
"""
User facing arguments used to define a quantization config
for weights or activations
:param num_bits: quantization bit depth
:param type: dtype to quantized to, either int or float
:param symmetric: whether or not quantization scale is symmetric
:param strategy: string determining the scope of scale/zero-point to apply
:param group_size: group length to use for the group strategy
:param block_structure: 2d block structure to use for the block
strategy, must be of the format "2x4", "8x16", etc.
:param dynamic: set True to perform dynamic quantization -
values will not be calibrated during calibration phase,
instead during inference new quantization ranges will be
observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one
"""
num_bits: int = 8
type: QuantizationType = QuantizationType.INT
symmetric: bool = True
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
dynamic: bool = False
observer: str = Field(
default="minmax",
description=("The class to use to compute the quantization param - "
"scale and zero-point'"),
)
observer_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description=
("optional dict of kwargs to be passed directly to torch quantization "
"Observers constructor excluding quantization range or symmetry"),
)
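
As a small illustration (values are made up, relying on the definitions above), an input-activation block for the dynamic per-token scheme deserializes into the enum-typed model, which is what CompressedTensorsConfig does with parse_obj:

args = QuantizationArgs.parse_obj({"num_bits": 8, "type": "int", "symmetric": True,
                                   "strategy": "token", "dynamic": True})
assert args.strategy == QuantizationStrategy.TOKEN and args.dynamic
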
def find_first_name_or_class_match(
name: str,
module: Module,
targets: Iterable[str],
check_contains: bool = False) -> Optional[str]:
"""
Helper function to map the quantization details listed in the config
for a given list of targets against each model layer. First uses the
layer name to try and find a match. If no name match is found, uses
the layer class name. Returns None otherwise.
:param name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name__, targets, check_contains)
def _find_first_match(value: str,
targets: Iterable[str],
check_contains: bool = False) -> Optional[str]:
"""
Returns first element of target that matches value either
exactly or as a regex after 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
:param value: string to compare the list of targets against
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
for target in targets:
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return target
elif check_contains:
if target.lower() in value.lower():
return target
elif target == value:
return target
return None
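
A brief usage sketch of the matching helpers defined above (layer and target names are illustrative): get_scheme calls find_first_name_or_class_match with an empty name and check_contains=True, so matching falls back to a substring check on the class name, while "re:"-prefixed targets are treated as regular expressions.

import torch

linear = torch.nn.Linear(8, 8)
targets = ["re:.*gate_proj$", "Linear"]

# The empty name never matches, so the class name "Linear" is matched by substring.
assert find_first_name_or_class_match("", linear, targets, check_contains=True) == "Linear"
# "re:" targets are applied with re.match against the provided value.
assert _find_first_match("model.layers.0.mlp.gate_proj", targets) == "re:.*gate_proj$"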