Add minimum capability requirement for AWQ (#1064)

Woosuk Kwon 2023-09-18 12:02:01 -07:00 committed by GitHub
parent cc796b1358
commit 2b1c116b5a
5 changed files with 47 additions and 2 deletions


@@ -11,9 +11,14 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer
#pragma once
namespace vllm {
namespace awq {
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
@@ -75,5 +80,8 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
#endif
}
} // namespace awq
} // namespace vllm


@@ -16,6 +16,9 @@ Adapted from https://github.com/mit-han-lab/llm-awq
#include <cuda_fp16.h>
namespace vllm {
namespace awq {
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
@@ -26,6 +29,9 @@ __pack_half2(const half x, const half y) {
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
@@ -214,11 +220,15 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
}
}
}
#endif
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
@@ -412,8 +422,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
}
}
}
#endif
}
} // namespace awq
} // namespace vllm
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
@@ -459,7 +473,7 @@ torch::Tensor awq_gemm(
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
@@ -470,7 +484,7 @@ torch::Tensor awq_gemm(
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);

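As a rough illustration of the tensor layout documented in the comments above (in_feats, kernel, scaling_factors) and of why awq_gemm dispatches on whether the output channel count divides 128 or only 64, here is a small Python sketch. The variable names, example sizes, and the zero-point shape are assumptions for illustration only; they are not part of this commit.

import torch

# Example sizes only; M, IC, OC, G are assumed values, not from the commit.
M, IC, OC, G = 16, 4096, 4096, 128

in_feats = torch.empty(M, IC, dtype=torch.half)              # activations [float16]
kernel = torch.empty(IC, OC // 8, dtype=torch.int32)         # 8 packed 4-bit weights per int32
scaling_factors = torch.empty(IC // G, OC, dtype=torch.half)
zeros = torch.empty(IC // G, OC // 8, dtype=torch.int32)     # packed zero points (shape assumed)

# The launcher above picks gemm_forward_4bit_cuda_m16n128k32 when OC % 128 == 0
# and falls back to gemm_forward_4bit_cuda_m16n64k32 when only OC % 64 == 0.
assert OC % 128 == 0 or OC % 64 == 0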

@@ -68,6 +68,14 @@ def get_model(model_config: ModelConfig) -> nn.Module:
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.download_dir)
        capability = torch.cuda.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        if capability < quant_config.get_min_capability():
            raise ValueError(
                f"The quantization method {model_config.quantization} is not "
                "supported for the current GPU. "
                f"Minimum capability: {quant_config.get_min_capability()}. "
                f"Current capability: {capability}.")
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(

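The check added above folds the (major, minor) tuple from torch.cuda.get_device_capability() into a single integer before comparing it to get_min_capability(). A minimal standalone sketch of the same arithmetic (the helper name and the example GPUs are illustrative, not from the commit):

import torch

def compute_capability() -> int:
    # e.g. (8, 0) on A100, (7, 5) on T4
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

# With AWQConfig.get_min_capability() == 80:
#   A100: (8, 0) -> 80, passes the check
#   T4:   (7, 5) -> 75, get_model() raises ValueError
if torch.cuda.is_available():
    print(compute_capability())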

@@ -40,6 +40,11 @@ class AWQConfig(QuantizationConfig):
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Ampere or newer GPUs.
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return [


@@ -15,6 +15,16 @@ class QuantizationConfig:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method.

        E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
        This requirement is due to the custom CUDA kernels used by the
        quantization method.
        """
        raise NotImplementedError

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""