[Kernel] Dynamic Per-Token Activation Quantization (#5037)
Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent dc49fb892c
commit ca3ea51bde
@@ -97,6 +97,9 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
 void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
                               torch::Tensor const& scale);
 
+void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
+                               torch::Tensor& scales);
+
 void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
                      torch::Tensor lookup_table);
 
@@ -70,6 +70,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
           "Compute int8 quantized tensor for given scaling factor");
 
+  ops.def("dynamic_scaled_int8_quant", &dynamic_scaled_int8_quant,
+          "Compute int8 quantized tensor and scaling factor");
+
   // Cache ops
   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
   cache_ops.def("swap_blocks", &swap_blocks,
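Once the extension is built, both bindings can be driven directly from Python. A minimal sketch (not part of this commit; shapes and values are illustrative) mirroring how the tests below call them:

import torch
from vllm._C import ops

num_tokens, hidden_size = 4, 768
x = torch.randn(num_tokens, hidden_size, dtype=torch.float16, device="cuda")

# Static path: the caller supplies a single per-tensor scale.
out = torch.empty_like(x, dtype=torch.int8)
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
ops.static_scaled_int8_quant(out, x, scale)

# Dynamic path: the kernel computes one scale per token and writes it out.
out_dyn = torch.empty_like(x, dtype=torch.int8)
scales = torch.empty(num_tokens, 1, dtype=torch.float32, device="cuda")
ops.dynamic_scaled_int8_quant(out_dyn, x, scales)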
@@ -3,6 +3,7 @@
 #include <cmath>
 
 #include "../../dispatch_utils.h"
+#include "../../reduction_utils.cuh"
 
 static inline __device__ int8_t float_to_int8_rn(float x) {
 #ifdef USE_ROCM
@@ -27,17 +28,48 @@ namespace vllm {
 template <typename scalar_t, typename scale_type>
 __global__ void static_scaled_int8_quant_kernel(
-    const scalar_t* __restrict__ input, int8_t* __restrict__ out,
-    const scale_type* scale_ptr, const int hidden_size) {
-  const int tid = threadIdx.x;
-  const int token_idx = blockIdx.x;
-  scale_type scale = *scale_ptr;
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type const* scale_ptr, const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  scale_type const scale = *scale_ptr;
 
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * hidden_size + i] =
-        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
   }
 }
+
+template <typename scalar_t, typename scale_type>
+__global__ void dynamic_scaled_int8_quant_kernel(
+    scalar_t const* __restrict__ input, int8_t* __restrict__ out,
+    scale_type* scale, const int hidden_size) {
+  int const tid = threadIdx.x;
+  int const token_idx = blockIdx.x;
+  float absmax_val = 0.0f;
+  float const zero = 0.0f;
+
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    float val = static_cast<float>(input[token_idx * hidden_size + i]);
+    val = val > zero ? val : -val;
+    absmax_val = val > absmax_val ? val : absmax_val;
+  }
+
+  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
+  __shared__ float block_absmax_val;
+  if (tid == 0) {
+    block_absmax_val = block_absmax_val_maybe;
+    scale[token_idx] = block_absmax_val / 127.0f;
+  }
+  __syncthreads();
+
+  float const tmp_scale = 127.0f / block_absmax_val;
+  for (int i = tid; i < hidden_size; i += blockDim.x) {
+    out[token_idx * hidden_size + i] = float_to_int8_rn(
+        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
+  }
+}
 
 }  // namespace vllm
 
 void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
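In PyTorch terms, each thread block of dynamic_scaled_int8_quant_kernel performs the following per-token computation. This is an illustrative reference sketch, not code from the commit; torch.round rounds half to even, matching float_to_int8_rn:

import torch

def dynamic_scaled_int8_quant_ref(x: torch.Tensor):
    # x: [num_tokens, hidden_size], any float dtype
    absmax = x.float().abs().amax(dim=-1, keepdim=True)  # per-token blockReduceMax
    scales = absmax / 127.0                              # written to scale[token_idx]
    # the kernel multiplies by tmp_scale = 127.0 / absmax, then rounds
    q = torch.round(x.float() * (127.0 / absmax)).to(torch.int8)
    return q, scales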
@@ -47,10 +79,10 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
   TORCH_CHECK(out.is_contiguous());
   TORCH_CHECK(scale.numel() == 1);
 
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  dim3 const grid(num_tokens);
+  dim3 const block(std::min(hidden_size, 1024));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "static_scaled_int8_quant_kernel", [&] {
@@ -60,3 +92,24 @@ void static_scaled_int8_quant(torch::Tensor& out,  // [..., hidden_size]
             scale.data_ptr<float>(), hidden_size);
       });
 }
+
+void dynamic_scaled_int8_quant(
+    torch::Tensor& out,          // [..., hidden_size]
+    torch::Tensor const& input,  // [..., hidden_size]
+    torch::Tensor& scales) {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
+
+  int const hidden_size = input.size(-1);
+  int const num_tokens = input.numel() / hidden_size;
+  dim3 const grid(num_tokens);
+  dim3 const block(std::min(hidden_size, 1024));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] {
+        vllm::dynamic_scaled_int8_quant_kernel<scalar_t, float>
+            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),
+                                         out.data_ptr<int8_t>(),
+                                         scales.data_ptr<float>(),
+                                         hidden_size);
+      });
+}
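Both launchers share the same launch geometry: one thread block per token, with at most 1024 threads striding across the hidden dimension. A small illustrative sketch of that arithmetic (not code from the commit):

def launch_shape(num_tokens: int, hidden_size: int):
    grid = num_tokens                # blockIdx.x == token_idx
    block = min(hidden_size, 1024)   # threads stride by blockDim.x over a row
    return grid, block

# e.g. launch_shape(83, 8193) == (83, 1024); each thread visits 8 or 9 elements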
@@ -21,29 +21,47 @@
 #include "cuda_compat.h"
 
 namespace vllm {
-template <typename T, int numLanes = WARP_SIZE>
-__inline__ __device__ T warpReduceSum(T val) {
-  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
-                "numLanes is not a positive power of 2!");
-  static_assert(numLanes <= WARP_SIZE);
-#pragma unroll
-  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
-    val += VLLM_SHFL_XOR_SYNC(val, mask);
-  return val;
+
+namespace detail {
+
+template <typename T>
+__inline__ __device__ T _max(T a, T b) {
+  return max(a, b);
 }
 
+template <typename T>
+__inline__ __device__ T _sum(T a, T b) {
+  return a + b;
+}
+
+}  // namespace detail
+
+template <typename T>
+using ReduceFnType = T (*)(T, T);
+
 // Helper function to return the next largest power of 2
 static constexpr int _nextPow2(unsigned int num) {
   if (num <= 1) return num;
   return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
 }
 
-/* Calculate the sum of all elements in a block */
+template <typename T, int numLanes = WARP_SIZE>
+__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
+  static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
+                "numLanes is not a positive power of 2!");
+  static_assert(numLanes <= WARP_SIZE);
+#pragma unroll
+  for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
+    val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
+
+  return val;
+}
+
 template <typename T, int maxBlockSize = 1024>
-__inline__ __device__ T blockReduceSum(T val) {
+__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
   static_assert(maxBlockSize <= 1024);
   if constexpr (maxBlockSize > WARP_SIZE) {
-    val = warpReduceSum<T>(val);
+    val = warpReduce<T>(val, fn);
     // Calculates max number of lanes that need to participate in the last
     // warpReduce
     constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
@@ -56,12 +74,22 @@ __inline__ __device__ T blockReduceSum(T val) {
 
     val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
                                                         : (T)(0.0f);
-    val = warpReduceSum<T, _nextPow2(maxActiveLanes)>(val);
+    val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
   } else {
     // A single warpReduce is equal to blockReduce
-    val = warpReduceSum<T, _nextPow2(maxBlockSize)>(val);
+    val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
   }
   return val;
 }
 
+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceMax(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
+}
+
+template <typename T, int maxBlockSize = 1024>
+__inline__ __device__ T blockReduceSum(T val) {
+  return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
+}
+
 }  // namespace vllm
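The refactor keeps the XOR-butterfly shuffle pattern and only parameterizes the combine step. A Python sketch of the same reduction (illustrative, not part of the commit). VLLM_SHFL_XOR_SYNC(val, mask) reads the value held by lane laneId ^ mask, which is also why the lane count must be a power of two:

def warp_reduce(vals, fn):
    lanes = len(vals)  # must be a power of two, e.g. WARP_SIZE == 32
    mask = lanes >> 1
    while mask > 0:
        # every lane combines its value with that of lane (i ^ mask)
        vals = [fn(v, vals[i ^ mask]) for i, v in enumerate(vals)]
        mask >>= 1
    return vals[0]  # after log2(lanes) rounds every lane holds the result

assert warp_reduce(list(range(32)), max) == 31
assert warp_reduce(list(range(32)), lambda a, b: a + b) == 496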
@@ -4,7 +4,8 @@ import torch
 from vllm._C import ops
 
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
+HIDDEN_SIZES = [16, 67, 768, 2048, 5120, 5137, 8192,
+                8193]  # Arbitrary values for testing
 NUM_TOKENS = [1, 7, 83, 4096]  # Arbitrary values for testing
 SEEDS = [0]
 SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
@@ -14,17 +15,48 @@ SCALE = [0.1, 0.5, 0.8, 1.2, 2.1]
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("scale", SCALE)
 @torch.inference_mode()
-def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
-               seed: int, scale: float) -> None:
+def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                   dtype: torch.dtype, seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
 
     x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
 
-    out1 = (x / scale).round().clamp(
-        torch.iinfo(torch.int8).min,
-        torch.iinfo(torch.int8).max).to(torch.int8)
+    x_token_max, _ = x.max(dim=1)
+    x_token_max = x_token_max.to(dtype=torch.float32)
+    scales = (x_token_max / float(127.0))[:, None].to(device="cuda",
+                                                      dtype=torch.float32)
+    torch_out = (x / scales).round().clamp(int8_traits.min,
+                                           int8_traits.max).to(torch.int8)
+
+    ops_out = torch.empty_like(x, dtype=torch.int8, device="cuda")
+    scales_out = torch.empty_like(scales, dtype=torch.float32, device="cuda")
+    ops.dynamic_scaled_int8_quant(ops_out, x, scales_out)
+
+    assert torch.allclose(scales_out, scales)
+    assert torch.allclose(torch_out, ops_out,
+                          atol=1)  # big atol to account for rounding errors
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("scale", SCALE)
+@torch.inference_mode()
+def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
+                                  dtype: torch.dtype, seed: int,
+                                  scale: float) -> None:
+    torch.random.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    int8_traits = torch.iinfo(torch.int8)
+
+    x = torch.rand(num_tokens, hidden_size, dtype=dtype, device="cuda") * 1000
+
+    out1 = (x / scale).round().clamp(int8_traits.min,
+                                     int8_traits.max).to(torch.int8)
     out2 = torch.empty_like(x, dtype=torch.int8)
     scale_argument = torch.tensor([scale], dtype=torch.float32, device="cuda")
@@ -6,7 +6,8 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
 import torch
 
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -34,3 +35,19 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner):
     assert qkv_proj.weight_scale.shard_splitter is not None
     assert qkv_proj.weight_scale.logical_widths is not None
     assert qkv_proj.input_scale.dtype is torch.float32
+
+
+def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner):
+    model_path = "nm-testing/tinyllama-one-shot-dynamic-test"
+    llm = vllm_runner(model_path,
+                      quantization="sparseml",
+                      enforce_eager=True,
+                      dtype=torch.float16)
+    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model
+    layer = model.model.layers[0]
+
+    qkv_proj = layer.self_attn.qkv_proj
+
+    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+    assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
+    assert qkv_proj.weight.dtype is torch.int8
@@ -266,21 +266,33 @@ def scaled_fp8_quant(
 
 
 # int8
-def static_scaled_int8_quant(input: torch.Tensor,
-                             scale: torch.Tensor) -> torch.Tensor:
+def scaled_int8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Quantize the input tensor to int8 and return the quantized tensor.
+    Quantize the input tensor to int8 and return the quantized tensor and scale.
 
     Args:
         input: The input tensor to be quantized to int8.
-        scale: Scaling factor for the int8 quantization.
+        scale: Optional scaling factor for the int8 quantization.
+            When not provided, we invoke dynamic-per-token quantization.
 
     Returns:
-        torch.Tensor: Output tensor in int8.
+        Tuple[torch.Tensor, torch.Tensor]: Output int8 tensor and scales.
     """
-    q = torch.empty_like(input, dtype=torch.int8)
-    vllm_ops.static_scaled_int8_quant(q, input, scale)
-    return q
+    output = torch.empty_like(input, dtype=torch.int8)
+    if scale is not None:
+        # static-per-tensor quantization.
+        vllm_ops.static_scaled_int8_quant(output, input, scale)
+        return output, scale
+
+    # dynamic-per-token quantization.
+    input_scales = torch.empty((input.numel() // input.shape[-1], 1),
+                               device=input.device,
+                               dtype=torch.float32)
+    vllm_ops.dynamic_scaled_int8_quant(output, input, input_scales)
+    return output, input_scales
 
 
 # moe
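Callers now receive a (quantized tensor, scales) pair in both modes. A usage sketch (not from the commit; shapes are illustrative):

import torch
from vllm import _custom_ops as custom_ops

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# static per-tensor: the provided scale is passed through unchanged
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
x_q, s = custom_ops.scaled_int8_quant(x, scale)   # s is `scale`

# dynamic per-token: scales are allocated as [num_tokens, 1] and filled in
x_q, token_scales = custom_ops.scaled_int8_quant(x)
assert token_scales.shape == (16, 1)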
@@ -1,12 +1,16 @@
 from typing import Any, Dict, List, Optional
 
 import torch
+from pydantic import BaseModel
 
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
 
 
 class CompressedTensorsConfig(QuantizationConfig):
@@ -47,10 +51,12 @@ class CompressedTensorsConfig(QuantizationConfig):
             targets = quant_config.get("targets")
             for target in targets:
                 layer_quant_details[target] = {}
-                layer_quant_details[target]["weight"] = quant_config.get(
-                    "weights")
-                layer_quant_details[target]["input"] = quant_config.get(
-                    "input_activations")
+                layer_quant_details[target][
+                    "weight"] = QuantizationArgs.parse_obj(
+                        quant_config.get("weights"))
+                layer_quant_details[target][
+                    "input"] = QuantizationArgs.parse_obj(
+                        quant_config.get("input_activations"))
 
         return cls(layer_quant_details=layer_quant_details, ignore=ignore)
 
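For reference, a hypothetical quant_config fragment of the shape this parsing expects; the field names follow QuantizationArgs (defined in the new utils module later in this commit), and real sparseml checkpoints may differ in detail:

quant_config = {
    "targets": ["Linear"],
    "weights": {"num_bits": 8, "strategy": "tensor",
                "symmetric": True, "dynamic": False},
    # dynamic-per-token activations: routes to CompressedTensorsW8A8DynamicToken
    "input_activations": {"num_bits": 8, "strategy": "token",
                          "symmetric": True, "dynamic": True},
}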
@@ -58,40 +64,46 @@
     def get_config_filenames(cls) -> List[str]:
         return []
 
-    def _get_schema(self, weight_quant: Dict, input_quant: Dict):
-        # TODO: Refactor as additional cases are supported
+    def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
+                               input_quant: BaseModel) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        is_tensor = (weight_quant.strategy == input_quant.strategy ==
+                     QuantizationStrategy.TENSOR.value)
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_static = not weight_quant.dynamic and not input_quant.dynamic
 
-        weight_bit = weight_quant.get("num_bits")
-        input_bit = input_quant.get("num_bits")
+        return is_8_bits and is_tensor and is_symmetric and is_static
 
-        weight_strategy = weight_quant.get("strategy")
-        input_strategy = input_quant.get("strategy")
+    def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
+                               input_quant: BaseModel) -> bool:
+        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
+        is_token_tensor = (weight_quant.strategy
+                           == QuantizationStrategy.TENSOR.value) and (
+                               input_quant.strategy
+                               == QuantizationStrategy.TOKEN.value)
+        is_symmetric = weight_quant.symmetric and input_quant.symmetric
+        is_dynamic = not weight_quant.dynamic and input_quant.dynamic
 
-        weight_symmetric = weight_quant.get("symmetric")
-        input_symmetric = input_quant.get("symmetric")
+        return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
 
-        is_8_bits = weight_bit == input_bit == 8
-        is_tensor = weight_strategy == input_strategy == "tensor"
-        is_symmetric = weight_symmetric and input_symmetric
-
-        if is_8_bits and is_tensor and is_symmetric and \
-                torch.cuda.is_available():
-            # CompressedTensorsW8A8StaticTensor only supports CUDA path for
-            # now.
+    def _get_schema(self, weight_quant: BaseModel,
+                    input_quant: BaseModel) -> "CompressedTensorsScheme":
+        if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8StaticTensor()
-        raise NotImplementedError(
-            "Scheme not supported. Only CUDA, 8-bit static symmtetric "
-            "per tensor quantization is currently supported")
+
+        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8DynamicToken()
+
+        raise NotImplementedError("Scheme not supported.")
 
     def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
 
         # TODO: update with matching function from `compressed_tensors`
-        layer_type_name = None
-        layer_name_class = type(layer).__name__.lower()
-        for target in self.layer_quant_details:
-            if target.lower() in layer_name_class:
-                layer_type_name = target
-                break
+        layer_type_name = find_first_name_or_class_match(
+            name="",
+            module=layer,
+            targets=self.layer_quant_details.keys(),
+            check_contains=True)
 
         if layer_type_name is None:
             raise ValueError(f"Could not matching target for layer {layer}")
@@ -117,7 +129,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
                        **extra_weight_attrs):
         """
         Use the CompressedTensorsScheme associated with each layer to create
-        the necessary parameters for the layer.
+        the necessary parameters for the layer. See LinearMethodBase for param
+        details
+
         """
         weight_loader = extra_weight_attrs.get("weight_loader")
 
@@ -139,7 +153,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         """
         Use the output of create_weights and the CompressedTensorsScheme
         associated with the layer to apply the forward pass with the
-        layer input.
+        layer input. See LinearMethodBase for param details
+
         """
 
         if bias is not None:
@@ -1,5 +1,7 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
+from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
+    CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
     CompressedTensorsW8A8StaticTensor)
@@ -0,0 +1,85 @@
+from typing import Callable, List, Tuple, Union
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as custom_ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW8A8DynamicToken"]
+
+
+class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        assert isinstance(shard_id, str)
+        qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        assert shard_id in qkv_idxs
+        return qkv_idxs[shard_id]
+
+    def scales_shard_splitter(
+            self, param: torch.Tensor, loaded_weight: torch.Tensor,
+            shard_id: Union[str, int],
+            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        shard_id = self._shard_id_as_int(shard_id)
+        offset = sum(logical_widths[:shard_id])
+        size = logical_widths[shard_id]
+        # update loaded weight with copies for broadcast.
+        loaded_weight = loaded_weight.repeat(size)
+        return param[offset:offset + size], loaded_weight
+
+    def create_weights(self, layer: torch.nn.Module,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        # When the scales have a single value, it is required that they be
+        # on the CPU for performance and CUDA Graphs compatibility. Please
+        # refer to the comment in
+        # CompressedTensorsW8A8StaticTensor::create_weights for further
+        # information.
+        is_tensor_partitioned = len(output_partition_sizes) != 1
+        weight_scale_dim = sum(
+            output_partition_sizes) if is_tensor_partitioned else 1
+
+        weight_zero_point = Parameter(torch.empty(1, dtype=torch.int8),
+                                      requires_grad=False)
+
+        weight_scale = Parameter(torch.empty(weight_scale_dim,
+                                             dtype=torch.float32),
+                                 requires_grad=False)
+
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=torch.int8),
+                           requires_grad=False)
+
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
+        set_weight_attrs(weight, {"weight_loader": weight_loader})
+        set_weight_attrs(weight, {"logical_widths": output_partition_sizes})
+
+        layer.register_parameter("weight_scale", weight_scale)
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        set_weight_attrs(
+            weight_scale, {
+                "shard_splitter": self.scales_shard_splitter,
+                "logical_widths": output_partition_sizes
+            })
+
+        layer.register_parameter("weight_zero_point", weight_zero_point)
+        set_weight_attrs(weight_zero_point, {"weight_loader": weight_loader})
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        weight = layer.weight
+        weight_scale = layer.weight_scale
+
+        x_q, input_scales = custom_ops.scaled_int8_quant(x)
+        return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales,
+                                               weight_scale, x.dtype)
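A worked example (illustrative numbers, not from the commit) of what scales_shard_splitter computes for a fused QKV projection:

logical_widths = [1024, 256, 256]             # q, k, v output widths
shard_id = 1                                  # "k" maps to index 1
offset = sum(logical_widths[:shard_id])       # 1024
size = logical_widths[shard_id]               # 256
# the method returns param[1024:1280] as the destination slice, and
# loaded_weight.repeat(size) broadcasts the single loaded scale 256 times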
@@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
         act_scale = layer.input_scale
 
         # Input quantize
-        x_q = custom_ops.static_scaled_int8_quant(x, act_scale)
+        x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
 
         return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
                                                weight_scale, x.dtype)
@@ -0,0 +1,114 @@
+import re
+from enum import Enum
+from typing import Any, Dict, Iterable, Optional
+
+from pydantic import BaseModel, Field
+from torch.nn import Module
+
+
+class QuantizationType(str, Enum):
+    """
+    Enum storing quantization type options
+    """
+
+    INT = "int"
+    FLOAT = "float"
+
+
+class QuantizationStrategy(str, Enum):
+    """
+    Enum storing quantization strategy options
+    """
+
+    TENSOR = "tensor"
+    CHANNEL = "channel"
+    GROUP = "group"
+    BLOCK = "block"
+    TOKEN = "token"
+
+
+class QuantizationArgs(BaseModel):
+    """
+    User facing arguments used to define a quantization config
+    for weights or activations
+
+    :param num_bits: quantization bit depth
+    :param type: dtype to quantized to, either int or float
+    :param symmetric: whether or not quantization scale is symmetric
+    :param strategy: string determining the scope of scale/zero-point to apply
+    :param group_size: group length to use for the group strategy
+    :param block_structure: 2d block structure to use for the block
+        strategy, must be of the format "2x4", "8x16", etc.
+    :param dynamic: set True to perform dynamic quantization -
+        values will not be calibrated during calibration phase,
+        instead during inference new quantization ranges will be
+        observed with every sample. Defaults to False for static
+        quantization. Note that enabling dynamic quantization
+        will change the default observer to a memoryless one
+    """
+
+    num_bits: int = 8
+    type: QuantizationType = QuantizationType.INT
+    symmetric: bool = True
+    group_size: Optional[int] = None
+    strategy: Optional[QuantizationStrategy] = None
+    block_structure: Optional[str] = None
+    dynamic: bool = False
+    observer: str = Field(
+        default="minmax",
+        description=("The class to use to compute the quantization param - "
+                     "scale and zero-point'"),
+    )
+    observer_kwargs: Dict[str, Any] = Field(
+        default_factory=dict,
+        description=
+        ("optional dict of kwargs to be passed directly to torch quantization "
+         "Observers constructor excluding quantization range or symmetry"),
+    )
+
+
+def find_first_name_or_class_match(
+        name: str,
+        module: Module,
+        targets: Iterable[str],
+        check_contains: bool = False) -> Optional[str]:
+    """
+    Helper function to map the quantization details listed in the config
+    for a given list of targets against each model layer. First uses the
+    layer name to try and find a match. If no name match is found, uses
+    the layer class name. Returns None otherwise.
+
+    :param name: layer name
+    :param module: torch.nn.Module
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    return _find_first_match(name, targets) or _find_first_match(
+        module.__class__.__name__, targets, check_contains)
+
+
+def _find_first_match(value: str,
+                      targets: Iterable[str],
+                      check_contains: bool = False) -> Optional[str]:
+    """
+    Returns first element of target that matches value either
+    exactly or as a regex after 're:'. If check_contains is set to True,
+    additionally checks if the target string is contained within the value.
+
+    :param value: string to compare the list of targets against
+    :param targets: list of targets to match the layer against
+    :param check_contains: whether or not to do a substring match
+    """
+
+    for target in targets:
+        if target.startswith("re:"):
+            pattern = target[3:]
+            if re.match(pattern, value):
+                return target
+        elif check_contains:
+            if target.lower() in value.lower():
+                return target
+        elif target == value:
+            return target
+    return None
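A usage sketch for the two matching helpers (the target strings here are hypothetical):

import torch.nn as nn

targets = ["re:.*q_proj", "Linear"]
_find_first_match("model.layers.0.self_attn.q_proj", targets)  # "re:.*q_proj"
find_first_name_or_class_match(name="model.layers.0.mlp.down_proj",
                               module=nn.Linear(8, 8),
                               targets=targets)  # falls back to "Linear"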