[Kernel] Dynamic Per-Token Activation Quantization (#5037)

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Dipika Sikka 2024-06-07 12:36:26 -04:00 committed by GitHub
parent dc49fb892c
commit ca3ea51bde
12 changed files with 439 additions and 75 deletions


@@ -97,6 +97,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);

void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor& scales);

void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                     torch::Tensor lookup_table);


@@ -70,6 +70,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
          "Compute int8 quantized tensor for given scaling factor");
  ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
          "Compute int8 quantized tensor and scaling factor");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def("swap_blocks", &swap_blocks,


@@ -3,6 +3,7 @@
#include <cmath>

#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"

static inline __device__ int8_t float_to_int8_rn(float x) {
#ifdef USE_ROCM
@@ -27,17 +28,48 @@ namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void static_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type const* scale_ptr, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  scale_type const scale = *scale_ptr;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
  }
}

template <typename scalar_t, typename scale_type>
__global__ void dynamic_scaled_int8_quant_kernel(
    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
    scale_type* scale, const int hidden_size) {
  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;
  float absmax_val = 0.0f;
  float const zero = 0.0f;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    float val = static_cast<float>(input[token_idx * hidden_size + i]);
    val = val > zero ? val : -val;
    absmax_val = val > absmax_val ? val : absmax_val;
  }

  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float block_absmax_val;
  if (tid == 0) {
    block_absmax_val = block_absmax_val_maybe;
    scale[token_idx] = block_absmax_val / 127.0f;
  }
  __syncthreads();

  float const tmp_scale = 127.0f / block_absmax_val;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] = float_to_int8_rn(
        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
  }
}

}  // namespace vllm

void static_scaled_int8_quant(torch::Tensor& out,          // [..., hidden_size]
@@ -47,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
  TORCH_CHECK(out.is_contiguous());
  TORCH_CHECK(scale.numel() == 1);

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -60,3 +92,24 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
            scale.data_ptr<float>(), hidden_size);
      });
}

void dynamic_scaled_int8_quant(
    torch::Tensor& out,          // [..., hidden_size]
    torch::Tensor const& input,  // [..., hidden_size]
    torch::Tensor& scales) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
                                         out.data_ptr<int8_t>(),
                                         scales.data_ptr<float>(), hidden_size);
      });
}
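
For reference, here is a minimal PyTorch sketch (not part of the diff) of the per-token math the dynamic kernel implements: each token row gets scale = absmax / 127, and the values are divided by that scale, rounded, and clamped to the int8 range. The helper name is illustrative only.

# Reference sketch of dynamic per-token int8 quantization (illustrative,
# not part of this commit). Mirrors the kernel: per-row absmax -> scale,
# then round and clamp to int8.
import torch

def dynamic_per_token_int8_quant_ref(x: torch.Tensor):
    int8_info = torch.iinfo(torch.int8)
    absmax = x.abs().max(dim=-1, keepdim=True).values.to(torch.float32)
    scales = absmax / 127.0  # [num_tokens, 1]
    x_q = (x / scales).round().clamp(int8_info.min,
                                     int8_info.max).to(torch.int8)
    return x_q, scales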


@@ -21,29 +21,47 @@
#include "cuda_compat.h"

namespace vllm {

namespace detail {

template <typename T>
__inline__ __device__ T _max(T a, T b) {
  return max(a, b);
}

template <typename T>
__inline__ __device__ T _sum(T a, T b) {
  return a + b;
}

}  // namespace detail

template <typename T>
using ReduceFnType = T (*)(T, T);

// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
  if (num <= 1) return num;
  return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}

template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
                "numLanes is not a positive power of 2!");
  static_assert(numLanes <= WARP_SIZE);
#pragma unroll
  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));

  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
  static_assert(maxBlockSize <= 1024);
  if constexpr (maxBlockSize > WARP_SIZE) {
    val = warpReduce<T>(val, fn);
    // Calculates max number of lanes that need to participate in the last
    // warpReduce
    constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {
    val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
                                                        : (T)(0.0f);
    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
  } else {
    // A single warpReduce is equal to blockReduce
    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
  }
  return val;
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}

template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}

}  // namespace vllm
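
To see why swapping the combine function is all blockReduceMax needs, here is an illustrative Python emulation (not part of the diff) of the XOR-butterfly pattern used by warpReduce: at each step lane i combines with lane i ^ mask, so after log2(numLanes) steps every lane holds the full reduction.

# Illustrative emulation of the warp-level XOR-butterfly reduction
# (plain Python, not part of this commit). `fn` plays the role of
# ReduceFnType: pass max for blockReduceMax, addition for blockReduceSum.
def butterfly_reduce(vals, fn):
    num_lanes = len(vals)  # must be a power of two, like WARP_SIZE
    vals = list(vals)
    mask = num_lanes >> 1
    while mask > 0:
        vals = [fn(vals[i], vals[i ^ mask]) for i in range(num_lanes)]
        mask >>= 1
    return vals  # every "lane" ends up holding the reduced value

print(butterfly_reduce([3.0, 1.0, 7.0, 2.0], max))                 # [7.0, 7.0, 7.0, 7.0]
print(butterfly_reduce([3.0, 1.0, 7.0, 2.0], lambda a, b: a + b))  # [13.0, 13.0, 13.0, 13.0]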


@@ -4,7 +4,8 @@ import torch
from vllm._C import ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
                8193]  # Arbitrary values for testing
NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
SEEDS = [0]
SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
@@ -14,17 +15,48 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@torch.inference_mode()
def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                   dtype: torch.dtype, seed: int) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    int8_traits = torch.iinfo(torch.int8)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

    x_token_max, _ = x.max(dim=1)
    x_token_max = x_token_max.to(dtype=torch.float32)
    scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
                                                      dtype=torch.float32)
    torch_out = (x / scales).round().clamp(int8_traits.min,
                                           int8_traits.max).to(torch.int8)

    ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
    scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
    ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)

    assert torch.allclose(scales_out, scales)
    assert torch.allclose(torch_out, ops_out,
                          atol=1)  # big atol to account for rounding errors


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("scale", SCALE)
@torch.inference_mode()
def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
                                  dtype: torch.dtype, seed: int,
                                  scale: float) -> None:
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    int8_traits = torch.iinfo(torch.int8)

    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000

    out1 = (x / scale).round().clamp(int8_traits.min,
                                     int8_traits.max).to(torch.int8)
    out2 = torch.empty_like(x, dtype=torch.int8)
    scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")


@@ -6,7 +6,8 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
import torch

from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
    CompressedTensorsW8A8StaticTensor)


def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -34,3 +35,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
    assert qkv_proj.weight_scale.shard_splitter is not None
    assert qkv_proj.weight_scale.logical_widths is not None
    assert qkv_proj.input_scale.dtype is torch.float32


def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner):
    model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
    llm = vllm_runner(model_path,
                      quantization="sparseml",
                      enforce_eager=True,
                      dtype=torch.float16)
    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
    layer = model.model.layers[0]

    qkv_proj = layer.self_attn.qkv_proj
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
    assert qkv_proj.weight.dtype is torch.int8


@@ -266,21 +266,33 @@ def scaled_fp8_quant(


# int8
def scaled_int8_quant(
        input: torch.Tensor,
        scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize the input tensor to int8 and return the quantized tensor and scale.

    Args:
        input: The input tensor to be quantized to int8.
        scale: Optional scaling factor for the int8 quantization.
            When not provided, we invoke dynamic-per-token quantization.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
    """
    output = torch.empty_like(input, dtype=torch.int8)
    if scale is not None:
        # static-per-tensor quantization.
        vllm_ops.static_scaled_int8_quant(output, input, scale)
        return output, scale

    # dynamic-per-token quantization.
    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
                               device=input.device,
                               dtype=torch.float32)
    vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
    return output, input_scales


# moe
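
A short usage sketch of the unified wrapper (assumes a CUDA build of vLLM with these kernels; the tensors and scale value are made up for illustration):

# Illustrative usage of scaled_int8_quant (not part of this commit).
import torch
from vllm import _custom_ops as custom_ops

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token: no scale passed; one float32 scale per token is returned.
x_q_dyn, token_scales = custom_ops.scaled_int8_quant(x)  # token_scales: [4, 1]

# Static per-tensor: a precomputed scale is passed and echoed back unchanged.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_q_static, same_scale = custom_ops.scaled_int8_quant(x, scale)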


@@ -1,12 +1,16 @@
from typing import Any, Dict, List, Optional

import torch
from pydantic import BaseModel

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
    CompressedTensorsW8A8StaticTensor)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)


class CompressedTensorsConfig(QuantizationConfig):
@@ -47,10 +51,12 @@ class CompressedTensorsConfig(QuantizationConfig):
        targets = quant_config.get("targets")
        for target in targets:
            layer_quant_details[target] = {}
            layer_quant_details[target][
                "weight"] = QuantizationArgs.parse_obj(
                    quant_config.get("weights"))
            layer_quant_details[target][
                "input"] = QuantizationArgs.parse_obj(
                    quant_config.get("input_activations"))

        return cls(layer_quant_details=layer_quant_details, ignore=ignore)
@@ -58,40 +64,46 @@ class CompressedTensorsConfig(QuantizationConfig):
    def get_config_filenames(cls) -> List[str]:
        return []

    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        is_tensor = (weight_quant.strategy == input_quant.strategy ==
                     QuantizationStrategy.TENSOR.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_static = not weight_quant.dynamic and not input_quant.dynamic

        return is_8_bits and is_tensor and is_symmetric and is_static

    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
                               input_quant: BaseModel) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        is_token_tensor = (weight_quant.strategy
                           == QuantizationStrategy.TENSOR.value) and (
                               input_quant.strategy
                               == QuantizationStrategy.TOKEN.value)
        is_symmetric = weight_quant.symmetric and input_quant.symmetric
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic

        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic

    def _get_schema(self, weight_quant: BaseModel,
                    input_quant: BaseModel) -> "CompressedTensorsScheme":
        if self._is_static_tensor_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8StaticTensor()

        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8DynamicToken()

        raise NotImplementedError("Scheme not supported.")

    def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":

        layer_type_name = find_first_name_or_class_match(
            name="",
            module=layer,
            targets=self.layer_quant_details.keys(),
            check_contains=True)

        if layer_type_name is None:
            raise ValueError(f"Could not match target for layer {layer}")
@@ -117,7 +129,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
                       **extra_weight_attrs):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
@@ -139,7 +153,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details.
        """

        if bias is not None:
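
To make the new scheme-selection predicates concrete, here is a hand-written example (not from the commit) of weight/activation QuantizationArgs that _is_dynamic_token_w8a8 accepts: 8-bit symmetric weights quantized statically per tensor, and 8-bit symmetric activations quantized dynamically per token.

# Hand-written example (not from this commit) of args accepted by
# _is_dynamic_token_w8a8; _get_schema would return
# CompressedTensorsW8A8DynamicToken for this pair.
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    QuantizationArgs)

weight_quant = QuantizationArgs.parse_obj({
    "num_bits": 8,
    "type": "int",
    "symmetric": True,
    "strategy": "tensor",
    "dynamic": False,
})
input_quant = QuantizationArgs.parse_obj({
    "num_bits": 8,
    "type": "int",
    "symmetric": True,
    "strategy": "token",
    "dynamic": True,
})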


@@ -1,5 +1,7 @@
from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
from .compressed_tensors_unquantized import (  # noqa: F401
    CompressedTensorsUnquantized)
from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
    CompressedTensorsW8A8DynamicToken)
from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
    CompressedTensorsW8A8StaticTensor)


@@ -0,0 +1,85 @@
from typing import Callable, List, Tuple, Union

import torch
from torch.nn import Parameter

from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsW8A8DynamicToken"]


class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        assert isinstance(shard_id, str)
        qkv_idxs = {"q": 0, "k": 1, "v": 2}
        assert shard_id in qkv_idxs
        return qkv_idxs[shard_id]

    def scales_shard_splitter(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int],
            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        shard_id = self._shard_id_as_int(shard_id)
        offset = sum(logical_widths[:shard_id])
        size = logical_widths[shard_id]
        # update loaded weight with copies for broadcast.
        loaded_weight = loaded_weight.repeat(size)
        return param[offset:offset + size], loaded_weight

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        # When the scales have a single value, it is required that they be
        # on the CPU for performance and CUDA Graphs compatibility. Please
        # refer to the comment in
        # CompressedTensorsW8A8StaticTensor::create_weights for further
        # information.
        is_tensor_partitioned = len(output_partition_sizes) != 1
        weight_scale_dim = sum(
            output_partition_sizes) if is_tensor_partitioned else 1

        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
                                      requires_grad=False)

        weight_scale = Parameter(torch.empty(weight_scale_dim,
                                             dtype=torch.float32),
                                 requires_grad=False)

        weight = Parameter(torch.empty(sum(output_partition_sizes),
                                       input_size_per_partition,
                                       dtype=torch.int8),
                           requires_grad=False)

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight, {"weight_loader": weight_loader})
        set_weight_attrs(weight, {"logical_widths": output_partition_sizes})

        layer.register_parameter("weight_scale", weight_scale)
        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
        set_weight_attrs(
            weight_scale, {
                "shard_splitter": self.scales_shard_splitter,
                "logical_widths": output_partition_sizes
            })

        layer.register_parameter("weight_zero_point", weight_zero_point)
        set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        weight = layer.weight
        weight_scale = layer.weight_scale

        x_q, input_scales = custom_ops.scaled_int8_quant(x)
        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
                                               weight_scale, x.dtype)
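
A small worked example (illustrative only, not part of the commit) of how scales_shard_splitter slots a loaded per-shard weight scale into the fused QKV scale parameter:

# Illustrative only: placing the "k" shard's weight scale into the fused
# parameter. logical_widths are the q/k/v output widths; the loaded scalar
# is repeated so it broadcasts across that shard's rows.
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A8DynamicToken)

scheme = CompressedTensorsW8A8DynamicToken()
logical_widths = [8, 4, 4]                       # q, k, v output widths
weight_scale = torch.zeros(sum(logical_widths))  # fused scale parameter

k_scale = torch.tensor([0.05])                   # scale loaded for the "k" shard
param_slice, expanded = scheme.scales_shard_splitter(weight_scale, k_scale,
                                                     "k", logical_widths)
param_slice.copy_(expanded)  # rows 8..11 of weight_scale now hold 0.05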


@@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
        act_scale = layer.input_scale

        # Input quantize
        x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)

        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
                                               weight_scale, x.dtype)


@@ -0,0 +1,114 @@
import re
from enum import Enum
from typing import Any, Dict, Iterable, Optional

from pydantic import BaseModel, Field
from torch.nn import Module


class QuantizationType(str, Enum):
    """
    Enum storing quantization type options
    """

    INT = "int"
    FLOAT = "float"


class QuantizationStrategy(str, Enum):
    """
    Enum storing quantization strategy options
    """

    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
    BLOCK = "block"
    TOKEN = "token"


class QuantizationArgs(BaseModel):
    """
    User-facing arguments used to define a quantization config
    for weights or activations.

    :param num_bits: quantization bit depth
    :param type: dtype to quantize to, either int or float
    :param symmetric: whether or not the quantization scale is symmetric
    :param strategy: string determining the scope of scale/zero-point to apply
    :param group_size: group length to use for the group strategy
    :param block_structure: 2d block structure to use for the block
        strategy, must be of the format "2x4", "8x16", etc.
    :param dynamic: set True to perform dynamic quantization -
        values will not be calibrated during a calibration phase;
        instead, new quantization ranges will be observed with
        every sample during inference. Defaults to False for static
        quantization. Note that enabling dynamic quantization
        will change the default observer to a memoryless one.
    """

    num_bits: int = 8
    type: QuantizationType = QuantizationType.INT
    symmetric: bool = True
    group_size: Optional[int] = None
    strategy: Optional[QuantizationStrategy] = None
    block_structure: Optional[str] = None
    dynamic: bool = False
    observer: str = Field(
        default="minmax",
        description=("The class to use to compute the quantization params - "
                     "scale and zero-point"),
    )
    observer_kwargs: Dict[str, Any] = Field(
        default_factory=dict,
        description=
        ("optional dict of kwargs to be passed directly to torch quantization "
         "Observers constructor excluding quantization range or symmetry"),
    )


def find_first_name_or_class_match(
        name: str,
        module: Module,
        targets: Iterable[str],
        check_contains: bool = False) -> Optional[str]:
    """
    Helper function to map the quantization details listed in the config
    for a given list of targets against each model layer. First uses the
    layer name to try and find a match. If no name match is found, uses
    the layer class name. Returns None otherwise.

    :param name: layer name
    :param module: torch.nn.Module
    :param targets: list of targets to match the layer against
    :param check_contains: whether or not to do a substring match
    """

    return _find_first_match(name, targets) or _find_first_match(
        module.__class__.__name__, targets, check_contains)


def _find_first_match(value: str,
                      targets: Iterable[str],
                      check_contains: bool = False) -> Optional[str]:
    """
    Returns the first element of targets that matches value, either
    exactly or as a regex after 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.

    :param value: string to compare the list of targets against
    :param targets: list of targets to match the layer against
    :param check_contains: whether or not to do a substring match
    """

    for target in targets:
        if target.startswith("re:"):
            pattern = target[3:]
            if re.match(pattern, value):
                return target
        elif check_contains:
            if target.lower() in value.lower():
                return target
        elif target == value:
            return target
    return None
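
A quick illustrative sketch (not from the commit) of how the matcher resolves a layer against config targets; with check_contains=True the class name is matched by a case-insensitive substring check, and "re:" targets are treated as regexes.

# Illustrative sketch (not part of this commit).
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    find_first_name_or_class_match)

targets = ["Linear", "re:.*lm_head"]
layer = torch.nn.Linear(16, 16)

# No layer name is given here, so the layer class name "Linear" is matched
# against the targets via the substring check.
match = find_first_name_or_class_match(name="",
                                        module=layer,
                                        targets=targets,
                                        check_contains=True)
print(match)  # "Linear"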