Add minimum capability requirement for AWQ (#1064)
commit 2b1c116b5a
parent cc796b1358
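In short: before loading a quantized model, vLLM now flattens the GPU's CUDA compute capability into an integer (major * 10 + minor) and compares it against the minimum the quantization method declares; AWQ declares 80, i.e. Ampere or newer. Below is a minimal sketch of that conversion using only `torch.cuda.get_device_capability()`; the helper name and the printed summary are illustrative, not part of this change.

```python
from typing import Tuple

import torch


def capability_as_int(capability: Tuple[int, int]) -> int:
    # Flatten (major, minor) into the integer used for the comparison,
    # e.g. (7, 0) -> 70 (Volta), (7, 5) -> 75 (Turing), (8, 0) -> 80 (Ampere).
    major, minor = capability
    return major * 10 + minor


if torch.cuda.is_available():
    current = capability_as_int(torch.cuda.get_device_capability())
    # AWQConfig.get_min_capability() returns 80, so anything below Ampere is rejected.
    print(f"Current capability: {current}, AWQ supported: {current >= 80}")
```

The integer form keeps the runtime check to a single comparison and matches the docstring examples added to `QuantizationConfig` further down in this diff.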
@@ -11,9 +11,14 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer
 #pragma once
 
+namespace vllm {
+namespace awq {
+
 __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    assert(false);
+#else
     uint4 result;
 
     uint32_t* h = reinterpret_cast<uint32_t*>(&result);
@@ -75,5 +80,8 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
     asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
 
     return result;
+#endif
 }
 
+} // namespace awq
+} // namespace vllm
@@ -16,6 +16,9 @@ Adapted from https://github.com/mit-han-lab/llm-awq
 #include <cuda_fp16.h>
 
+namespace vllm {
+namespace awq {
+
 // Pack two half values.
 static inline __device__ __host__ unsigned
 __pack_half2(const half x, const half y) {
@@ -26,6 +29,9 @@ __pack_half2(const half x, const half y) {
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -214,11 +220,15 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
       }
     }
   }
+#endif
 }
 
 
 __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+  assert(false);
+#else
   static constexpr uint32_t ZERO = 0x0;
   float C_warp[32];
   __shared__ half A_shared[16 * (32 + 8)];
@@ -412,8 +422,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
       }
     }
   }
+#endif
 }
 
+} // namespace awq
+} // namespace vllm
+
 // in_feats: M, IC [float16]
 // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 // scaling_factors: IC // G, OC [float16]
@@ -459,7 +473,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     else if (num_out_channels % 64 == 0)
@@ -470,7 +484,7 @@ torch::Tensor awq_gemm(
         // threadIdx.x: 32
         // threadIdx.y: i_factors[2] * j_factors[2]
         dim3 threads_per_block(32, 2);
-        gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
             group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
     }
     return _out_feats.sum(0);
@@ -68,6 +68,14 @@ def get_model(model_config: ModelConfig) -> nn.Module:
         quant_config = get_quant_config(model_config.quantization,
                                         model_config.model,
                                         model_config.download_dir)
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability < quant_config.get_min_capability():
+            raise ValueError(
+                f"The quantization method {model_config.quantization} is not "
+                "supported for the current GPU. "
+                f"Minimum capability: {quant_config.get_min_capability()}. "
+                f"Current capability: {capability}.")
         supported_dtypes = quant_config.get_supported_act_dtypes()
         if model_config.dtype not in supported_dtypes:
             raise ValueError(
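On a pre-Ampere GPU the check above fails fast, before any weights are fetched or loaded. Here is a standalone sketch of the same branch with assumed values (a Turing card reporting capability (7, 5), and 80 as the value `AWQConfig.get_min_capability()` returns); the function and constant names are illustrative stand-ins, not vLLM APIs.

```python
MIN_CAPABILITY = 80  # what AWQConfig.get_min_capability() returns


def check_capability(major: int, minor: int, quantization: str = "awq") -> None:
    # Mirror the loader check: flatten to an integer and compare to the floor.
    capability = major * 10 + minor
    if capability < MIN_CAPABILITY:
        raise ValueError(
            f"The quantization method {quantization} is not "
            "supported for the current GPU. "
            f"Minimum capability: {MIN_CAPABILITY}. "
            f"Current capability: {capability}.")


# A Turing GPU (e.g. compute capability (7, 5) -> 75) is rejected:
try:
    check_capability(7, 5)
except ValueError as e:
    print(e)
```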
@@ -40,6 +40,11 @@ class AWQConfig(QuantizationConfig):
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half]
 
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Ampere or newer GPUs.
+        return 80
+
     @classmethod
     def get_config_filenames(cls) -> List[str]:
         return [
@@ -15,6 +15,16 @@ class QuantizationConfig:
         """List of supported activation dtypes."""
         raise NotImplementedError
 
+    @classmethod
+    def get_min_capability(cls) -> int:
+        """Minimum GPU capability to support the quantization method.
+
+        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
+        This requirement is due to the custom CUDA kernels used by the
+        quantization method.
+        """
+        raise NotImplementedError
+
     @classmethod
     def get_config_filenames(cls) -> List[str]:
         """List of filenames to search for in the model directory."""
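Other quantization backends can declare their own floor by overriding the same classmethod. Below is a minimal, self-contained sketch of the pattern; the base class is a trimmed stand-in for `QuantizationConfig`, and the subclass name and its capability value are invented for illustration.

```python
class QuantizationConfig:
    """Trimmed stand-in for the base class shown in the diff above."""

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method."""
        raise NotImplementedError


class HypotheticalInt8Config(QuantizationConfig):
    @classmethod
    def get_min_capability(cls) -> int:
        # Suppose this method's kernels only require Turing or newer GPUs.
        return 75


assert HypotheticalInt8Config.get_min_capability() == 75
```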