From e3e79e9e8a2224e03a711c3d1ef7a35daa447083 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sat, 16 Sep 2023 00:03:37 -0700 Subject: [PATCH] Implement AWQ quantization support for LLaMA (#1032) Co-authored-by: Robert Irvine Co-authored-by: root Co-authored-by: Casper Co-authored-by: julian-q --- .gitignore | 4 + benchmarks/benchmark_latency.py | 18 +- benchmarks/benchmark_throughput.py | 74 +-- csrc/quantization.cpp | 15 + csrc/quantization/awq/dequantize.cuh | 79 +++ csrc/quantization/awq/gemm_kernels.cu | 477 ++++++++++++++++++ setup.py | 14 + vllm/config.py | 16 + vllm/engine/arg_utils.py | 12 +- vllm/engine/llm_engine.py | 1 + .../layers/quantized_linear/__init__.py | 37 ++ .../layers/quantized_linear/awq.py | 102 ++++ vllm/model_executor/model_loader.py | 30 +- vllm/model_executor/models/llama.py | 125 +++-- .../parallel_utils/tensor_parallel/layers.py | 177 ++----- .../quantization_utils/__init__.py | 20 + vllm/model_executor/quantization_utils/awq.py | 67 +++ .../model_executor/quantization_utils/base.py | 65 +++ vllm/model_executor/weight_utils.py | 53 +- 19 files changed, 1178 insertions(+), 208 deletions(-) create mode 100644 csrc/quantization.cpp create mode 100644 csrc/quantization/awq/dequantize.cuh create mode 100644 csrc/quantization/awq/gemm_kernels.cu create mode 100644 vllm/model_executor/layers/quantized_linear/__init__.py create mode 100644 vllm/model_executor/layers/quantized_linear/awq.py create mode 100644 vllm/model_executor/quantization_utils/__init__.py create mode 100644 vllm/model_executor/quantization_utils/awq.py create mode 100644 vllm/model_executor/quantization_utils/base.py diff --git a/.gitignore b/.gitignore index da5a337c..b531b791 100644 --- a/.gitignore +++ b/.gitignore @@ -173,3 +173,7 @@ cython_debug/ # Sphinx documentation _build/ + +# vim swap files +*.swo +*.swp diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8269481e..be50a5f4 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -18,6 +18,7 @@ def main(args: argparse.Namespace): llm = LLM( model=args.model, tokenizer=args.tokenizer, + quantization=args.quantization, tensor_parallel_size=args.tensor_parallel_size, max_num_seqs=args.batch_size, max_num_batched_tokens=args.batch_size * args.input_len, @@ -63,19 +64,28 @@ def main(args: argparse.Namespace): if __name__ == '__main__': parser = argparse.ArgumentParser( description='Benchmark the latency of processing a single batch of ' - 'requests till completion.') + 'requests till completion.') parser.add_argument('--model', type=str, default='facebook/opt-125m') parser.add_argument('--tokenizer', type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=['awq', None], + default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) parser.add_argument('--output-len', type=int, default=128) parser.add_argument('--batch-size', type=int, default=8) - parser.add_argument('--n', type=int, default=1, + parser.add_argument('--n', + type=int, + default=1, help='Number of generated sequences per prompt.') parser.add_argument('--use-beam-search', action='store_true') - parser.add_argument('--num-iters', type=int, default=3, + parser.add_argument('--num-iters', + type=int, + default=3, help='Number of iterations to run.') - parser.add_argument('--trust-remote-code', action='store_true', + parser.add_argument('--trust-remote-code', + action='store_true', help='trust remote code from 
huggingface') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index b2bea852..c200deb6 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -3,7 +3,7 @@ import argparse import json import random import time -from typing import List, Tuple +from typing import List, Optional, Tuple import torch from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase @@ -22,15 +22,10 @@ def sample_requests( with open(dataset_path) as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. - dataset = [ - data for data in dataset - if len(data["conversations"]) >= 2 - ] + dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. - dataset = [ - (data["conversations"][0]["value"], data["conversations"][1]["value"]) - for data in dataset - ] + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] # Tokenize the prompts and completions. prompts = [prompt for prompt, _ in dataset] @@ -63,6 +58,7 @@ def run_vllm( requests: List[Tuple[str, int, int]], model: str, tokenizer: str, + quantization: Optional[str], tensor_parallel_size: int, seed: int, n: int, @@ -72,6 +68,7 @@ def run_vllm( llm = LLM( model=model, tokenizer=tokenizer, + quantization=quantization, tensor_parallel_size=tensor_parallel_size, seed=seed, trust_remote_code=trust_remote_code, @@ -111,8 +108,8 @@ def run_hf( trust_remote_code: bool, ) -> float: assert not use_beam_search - llm = AutoModelForCausalLM.from_pretrained(model, - torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token @@ -132,13 +129,14 @@ def run_hf( if len(batch) < max_batch_size and i != len(requests) - 1: # Check if we can add more requests to the batch. _, next_prompt_len, next_output_len = requests[i + 1] - if (max(max_prompt_len, next_prompt_len) + max( - max_output_len, next_output_len)) <= 2048: + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: # We can add more requests to the batch. continue # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), do_sample=not use_beam_search, @@ -165,44 +163,58 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. 
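Annotation: the benchmark changes above simply thread a new --quantization/-q flag through to the LLM constructor; nothing else in the request path changes. A minimal usage sketch of the same entry point from Python (the checkpoint name is a placeholder, an actual AWQ-quantized LLaMA checkpoint is needed in practice):

    from vllm import LLM, SamplingParams

    # Placeholder path: substitute a real AWQ-quantized LLaMA checkpoint.
    llm = LLM(model="path/to/llama-7b-awq", quantization="awq")
    sampling = SamplingParams(temperature=0.8, max_tokens=32)
    for output in llm.generate(["Hello, my name is"], sampling):
        print(output.outputs[0].text)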
- tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer(args.tokenizer, + trust_remote_code=args.trust_remote_code) requests = sample_requests(args.dataset, args.num_prompts, tokenizer) if args.backend == "vllm": - elapsed_time = run_vllm( - requests, args.model, args.tokenizer, args.tensor_parallel_size, - args.seed, args.n, args.use_beam_search, args.trust_remote_code) + elapsed_time = run_vllm(requests, args.model, args.tokenizer, + args.quantization, args.tensor_parallel_size, + args.seed, args.n, args.use_beam_search, + args.trust_remote_code) elif args.backend == "hf": assert args.tensor_parallel_size == 1 - elapsed_time = run_hf( - requests, args.model, tokenizer, args.n, args.use_beam_search, - args.hf_max_batch_size, args.trust_remote_code) + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.use_beam_search, args.hf_max_batch_size, + args.trust_remote_code) else: raise ValueError(f"Unknown backend: {args.backend}") - total_num_tokens = sum( - prompt_len + output_len - for _, prompt_len, output_len in requests - ) + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len in requests) print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " f"{total_num_tokens / elapsed_time:.2f} tokens/s") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", type=str, choices=["vllm", "hf"], + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf"], default="vllm") - parser.add_argument("--dataset", type=str, required=True, + parser.add_argument("--dataset", + type=str, + required=True, help="Path to the dataset.") parser.add_argument("--model", type=str, default="facebook/opt-125m") parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=['awq', None], + default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) - parser.add_argument("--n", type=int, default=1, + parser.add_argument("--n", + type=int, + default=1, help="Number of generated sequences per prompt.") parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument("--num-prompts", type=int, default=1000, + parser.add_argument("--num-prompts", + type=int, + default=1000, help="Number of prompts to process.") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--hf-max-batch-size", type=int, default=None, + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, help="Maximum batch size for HF backend.") parser.add_argument('--trust-remote-code', action='store_true', @@ -215,6 +227,8 @@ if __name__ == "__main__": elif args.backend == "hf": if args.hf_max_batch_size is None: raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") if args.tokenizer is None: args.tokenizer = args.model diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp new file mode 100644 index 00000000..3afa7f6a --- /dev/null +++ b/csrc/quantization.cpp @@ -0,0 +1,15 @@ +#include + +torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "awq_gemm", + &awq_gemm, + "Quantized GEMM for AWQ"); +} diff --git a/csrc/quantization/awq/dequantize.cuh 
b/csrc/quantization/awq/dequantize.cuh new file mode 100644 index 00000000..060a4ef9 --- /dev/null +++ b/csrc/quantization/awq/dequantize.cuh @@ -0,0 +1,79 @@ +/* +Adapted from https://github.com/mit-han-lab/llm-awq +Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +@article{lin2023awq, + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} +} +*/ + +#pragma once + + +__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) +{ + uint4 result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. + const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + // static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7] + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + // static constexpr uint32_t NEG_72 = 0xd480d480; + // Haotian: Let's use {-64, -64}. + static constexpr uint32_t NEG_64 = 0xd400d400; + + // Finally, we construct the output numbers. 
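Annotation: the inline PTX above leans on an fp16 magic-number trick. OR-ing a 4-bit value q into the mantissa of 0x6400 (fp16 for 1024.0) encodes exactly 1024 + q, so one sub.f16x2 recovers q; the nibble sitting in bits 4..7 of the same half encodes 1024 + 16*q, which the fma.rn.f16x2 by 1/16 with addend -64 undoes. A small NumPy check of that arithmetic (done in fp32 here, but every intermediate value is also exactly representable in fp16):

    import numpy as np

    q = np.arange(16, dtype=np.uint16)
    lo = (0x6400 | q).view(np.float16)          # nibble in bits 0..3: encodes 1024 + q
    hi = (0x6400 | (q << 4)).view(np.float16)   # nibble in bits 4..7: encodes 1024 + 16*q
    assert np.all(lo.astype(np.float32) - 1024.0 == q)       # the sub.f16x2 path
    assert np.all(hi.astype(np.float32) / 16.0 - 64.0 == q)  # the fma.rn.f16x2 path
    print("q = 0..15 recovered exactly")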
+ // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64)); + + return result; +} + diff --git a/csrc/quantization/awq/gemm_kernels.cu b/csrc/quantization/awq/gemm_kernels.cu new file mode 100644 index 00000000..895845a3 --- /dev/null +++ b/csrc/quantization/awq/gemm_kernels.cu @@ -0,0 +1,477 @@ +/* +Adapted from https://github.com/mit-han-lab/llm-awq +@article{lin2023awq, + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} +} + */ + + +#include +#include + +#include "dequantize.cuh" + +#include + +// Pack two half values. +static inline __device__ __host__ unsigned +__pack_half2(const half x, const half y) { + unsigned v0 = *((unsigned short *)&x); + unsigned v1 = *((unsigned short *)&y); + return (v1 << 16) | v0; +} + +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +{ + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (128 + 8)]; + + __shared__ half scaling_factors_shared[128]; + __shared__ half zeros_shared[128]; + + int j_factors1 = ((OC + 128 - 1) / 128); + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); + + half A_shared_warp[8]; + half B_shared_warp[32]; + for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / 128; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < M; + + half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * 2 + + (((int)threadIdx.x) / (128 / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (128 / 8) + + (((int)threadIdx.x) % (128 / 8)) * 1; +// Why * 1 in the above line? 
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) + + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) + + (((int)threadIdx.x) % (128 / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (128 / 8) + + ((int)threadIdx.x) % (128 / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * (128) + + (((int)threadIdx.x) % (128 / 8)) * 8; + + half* C_ptr = C + + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * 128 + + ((int)threadIdx.y) * 64 + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) { + + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // - zero and * scale + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
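Annotation: the eight sub/fma pairs that follow apply w = (q - z) * s to the dequantized halves. For reference, the same transformation over whole AWQ tensors can be sketched in NumPy as below; the nibble ordering AWQ_ORDER is an assumption read off the interleaving in dequantize_s4_to_fp16x2, not something this patch documents:

    import numpy as np

    # Assumed output order of the 8 nibbles packed in each int32 (see dequantize.cuh).
    AWQ_ORDER = (0, 4, 1, 5, 2, 6, 3, 7)

    def unpack_int4(packed: np.ndarray) -> np.ndarray:
        """Unpack (..., N) int32 into (..., N * 8) unsigned 4-bit values."""
        shifts = np.array([4 * i for i in AWQ_ORDER], dtype=np.uint32)
        nibbles = (packed.astype(np.uint32)[..., None] >> shifts) & 0xF
        return nibbles.reshape(*packed.shape[:-1], -1)

    def awq_dequantize(qweight, qzeros, scales, group_size=128):
        """qweight: (IC, OC//8) int32, qzeros: (IC//G, OC//8) int32, scales: (IC//G, OC)."""
        q = unpack_int4(qweight).astype(np.float32)                   # (IC, OC)
        z = np.repeat(unpack_int4(qzeros).astype(np.float32), group_size, axis=0)
        s = np.repeat(scales.astype(np.float32), group_size, axis=0)  # (IC, OC)
        return (q - z) * s

x @ awq_dequantize(qweight, qzeros, scales) is then the product that awq_gemm approximates with fp16 tensor cores.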
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + + + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + } + } + } + +// TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } +} + + +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +{ + static constexpr uint32_t ZERO = 0x0; + float C_warp[32]; + __shared__ half A_shared[16 * (32 + 8)]; + __shared__ half B_shared[32 * (64 + 8)]; + + __shared__ half scaling_factors_shared[64]; + __shared__ half zeros_shared[64]; + + int j_factors1 = ((OC + 64 - 1) / 64); + + int blockIdx_x = 0; + int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); + int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); + + half A_shared_warp[8]; + half B_shared_warp[16]; + for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { + for (int i = 0; i < 8; ++i) { + C_warp[(j_0_4_init * 8) + i] = 0.0; + } + } + + static constexpr int row_stride_warp = 32 * 8 / 32; + static constexpr int row_stride = 2 * 32 * 8 / 64; + bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id + // bool wb_C_flag = (threadIdx.x / 4) < M; + + half* A_ptr = A + + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + + (((int)threadIdx.x) % (32 / 8)) * 8; + + int* B_ptr = B + + ((int)threadIdx.y) * (OC / 8) * 4 + + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) + + (((int)blockIdx_y) % j_factors1) * (64 / 8) + + (((int)threadIdx.x) % (64 / 8)) * 1; +// Why * 1 in the above line? 
+ + half* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + + (((int)threadIdx.x) % (32 / 8) ) * 8; + + half* B_shared_ptr = B_shared + + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) + + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) + + (((int)threadIdx.x) % (64 / 8)) * 8; + + int* zeros_ptr = zeros + + (((int)blockIdx_y) % j_factors1) * (64 / 8) + + ((int)threadIdx.x) % (64 / 8); + + half* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * (64) + + (((int)threadIdx.x) % (64 / 8)) * 8; + + half* C_ptr = C + + blockIdx_z * M * OC // blockIdz.x -> split_k dim + + (((int)blockIdx_y) % j_factors1) * 64 + + ((int)threadIdx.y) * 32 + + (((int)threadIdx.x) % 4) * 2; + + // preload s.f. and zeros + int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; + if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; + for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { + int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; + __syncthreads(); + // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 + if (ld_A_flag) + { + *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); + } + else + { + *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); + } + + // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { + uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); + uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); + /* + if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ + printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); + } + */ + // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); + int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); + + for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { + + // B: 32 x 136 (128+8) float16 + // each warp: 32 x 4 + // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 + // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); + // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) + uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); + uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); + //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); + + // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); + // - zero and * scale + // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. 
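Annotation: this second kernel only narrows the tile width (cta_N = 64 instead of 128); the split-K structure is shared. Each blockIdx_z slice of the reduction dimension writes one partial result, and the host-side awq_gemm further below allocates a (split_k_iters, M, OC) buffer and reduces it with .sum(0). The split-K idea in isolation, as a NumPy sketch with contiguous K slices (the real kernels interleave the slices, which does not change the sum):

    import numpy as np

    def splitk_matmul(x, w, split_k_iters=8):
        """Compute x @ w as split_k_iters partial products over K, then sum them,
        mirroring the (split_k_iters, M, OC) buffer that awq_gemm sums over."""
        k = x.shape[1]
        assert k % split_k_iters == 0
        step = k // split_k_iters
        partials = np.stack([x[:, i * step:(i + 1) * step] @ w[i * step:(i + 1) * step]
                             for i in range(split_k_iters)])
        return partials.sum(axis=0)

    x = np.random.randn(16, 4096).astype(np.float32)
    w = np.random.randn(4096, 512).astype(np.float32)
    assert np.allclose(splitk_matmul(x, w), x @ w, rtol=1e-4, atol=1e-3)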
+ asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + /* + if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ + printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); + } + */ + + // write back + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16; + } + __syncthreads(); + + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) + { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3]) + : "r"(addr) + ); + } + + + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) + { + { + unsigned int addr; + __asm__ __volatile__( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + __asm__ __volatile__( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); + } + } + + for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) + { + + { + __asm__ __volatile__( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); + } + + { + __asm__ __volatile__( + 
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); + } + } + } + } + +// TODO: Shang: Hoist loop invariance. + for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { + for (int local_id = 0; local_id < 8; ++local_id) { + int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; + if (row_offset < M) + { + *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + } + } + } +} + +// in_feats: M, IC [float16] +// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] +// scaling_factors: IC // G, OC [float16] +// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] +// assume that batch_size < 16 for now + +torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + + auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + int group_size = num_in_channels / _scaling_factors.size(0); + + if (num_out_channels % 64 != 0) + throw std::invalid_argument("OC is not multiple of cta_N = 64"); + if (num_out_channels % 8 != 0) + throw std::invalid_argument("OC is not multiple of pack_num = 8"); + if (group_size % 32 != 0) + throw std::invalid_argument("Group size should be a multiple of 32"); + if (num_out_channels % group_size != 0) + throw std::invalid_argument("OC is not multiple of Group size"); + + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * j_factors[2] + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n128k32<<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + + // threadIdx.x: 32 + // threadIdx.y: i_factors[2] * 
j_factors[2] + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n64k32<<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + } + return _out_feats.sum(0); +} diff --git a/setup.py b/setup.py index f1ee90f1..047ee8d0 100644 --- a/setup.py +++ b/setup.py @@ -146,6 +146,20 @@ activation_extension = CUDAExtension( ) ext_modules.append(activation_extension) +# Quantization kernels. +quantization_extension = CUDAExtension( + name="vllm.quantization_ops", + sources=[ + "csrc/quantization.cpp", + "csrc/quantization/awq/gemm_kernels.cu", + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, +) +ext_modules.append(quantization_extension) + def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) diff --git a/vllm/config.py b/vllm/config.py index aa8c4dc3..dd92fbcc 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -43,6 +43,8 @@ class ModelConfig: version. max_model_len: Maximum length of a sequence (including prompt and output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. """ def __init__( @@ -57,6 +59,7 @@ class ModelConfig: seed: int, revision: Optional[str], max_model_len: Optional[int] = None, + quantization: Optional[str] = None, ) -> None: self.model = model self.tokenizer = tokenizer @@ -66,11 +69,13 @@ class ModelConfig: self.load_format = load_format self.seed = seed self.revision = revision + self.quantization = quantization self.hf_config = get_config(model, trust_remote_code, revision) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self._verify_load_format() self._verify_tokenizer_mode() + self._verify_quantization() self.max_model_len = None if max_model_len is not None: derived_max_model_len = self.get_max_model_len() @@ -100,6 +105,17 @@ class ModelConfig: "either 'auto' or 'slow'.") self.tokenizer_mode = tokenizer_mode + def _verify_quantization(self) -> None: + supported_quantization = ["awq"] + if self.quantization is None: + return + quantization = self.quantization.lower() + if quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization: {self.quantization}. Must be one of " + f"{supported_quantization}.") + self.quantization = quantization + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9478e800..a03155a4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -29,6 +29,7 @@ class EngineArgs: max_num_seqs: int = 256 disable_log_stats: bool = False revision: Optional[str] = None + quantization: Optional[str] = None def __post_init__(self): if self.tokenizer is None: @@ -88,7 +89,6 @@ class EngineArgs: 'a numpy cache to speed up the loading. ' '"dummy" will initialize the weights with random values, ' 'which is mainly for profiling.') - # TODO(woosuk): Support FP32. parser.add_argument( '--dtype', type=str, @@ -150,6 +150,13 @@ class EngineArgs: parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') + # Quantization settings. 
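Annotation: on the configuration side, ModelConfig gains a quantization field that _verify_quantization lower-cases and checks against the supported list (currently just "awq"), and EngineArgs forwards it into ModelConfig. A minimal engine-level sketch; the checkpoint name is illustrative and must point at an actual AWQ-quantized model for this to run:

    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.llm_engine import LLMEngine

    # Equivalent to passing "--quantization awq" (or "-q awq") on the command line.
    engine_args = EngineArgs(model="TheBloke/Llama-2-7B-AWQ",   # illustrative name
                             quantization="awq",
                             dtype="half")                      # the AWQ kernels expect fp16
    engine = LLMEngine.from_engine_args(engine_args)

Passing an unsupported method (for example quantization="gptq") makes ModelConfig._verify_quantization raise a ValueError before any weights are loaded.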
+ parser.add_argument('--quantization', + '-q', + type=str, + choices=['awq', None], + default=None, + help='Method used to quantize the weights') return parser @classmethod @@ -163,12 +170,11 @@ class EngineArgs: def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: - # Initialize the configs. model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, self.dtype, self.seed, self.revision, - self.max_model_len) + self.max_model_len, self.quantization) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1b0f50e3..859b6e15 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -80,6 +80,7 @@ class LLMEngine: f"download_dir={model_config.download_dir!r}, " f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"quantization={model_config.quantization}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py new file mode 100644 index 00000000..bcb9a54e --- /dev/null +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -0,0 +1,37 @@ +from vllm.model_executor.layers.quantized_linear.awq import ( + AWQColumnParallelLinear, AWQRowParallelLinear) +from vllm.model_executor.parallel_utils.tensor_parallel import ( + ColumnParallelLinear, RowParallelLinear) + +_QUANTIZED_LINEAR_REGISTRY = { + "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), +} + + +class ParallelLinear: + + @classmethod + def column(cls, *args, **kwargs) -> ColumnParallelLinear: + quant_config = kwargs.get("quant_config", None) + if quant_config is None: + return ColumnParallelLinear(*args, **kwargs) + + name = quant_config.get_name() + if name not in _QUANTIZED_LINEAR_REGISTRY: + raise ValueError(f"No quantized linear is found for {name}") + + quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0] + return quant_linear_cls(*args, **kwargs) + + @classmethod + def row(cls, *args, **kwargs) -> RowParallelLinear: + quant_config = kwargs.get("quant_config", None) + if quant_config is None: + return RowParallelLinear(*args, **kwargs) + + name = quant_config.get_name() + if name not in _QUANTIZED_LINEAR_REGISTRY: + raise ValueError(f"No quantized linear is found for {name}") + + quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1] + return quant_linear_cls(*args, **kwargs) diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py new file mode 100644 index 00000000..88c5790a --- /dev/null +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -0,0 +1,102 @@ +from typing import Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import quantization_ops +from vllm.model_executor.parallel_utils.tensor_parallel.layers import ( + ColumnParallelLinear, RowParallelLinear) + + +class AWQColumnParallelLinear(ColumnParallelLinear): + + def create_weights(self, dtype: torch.dtype) -> None: + assert self.input_size % self.quant_config.weight_bits == 0 + assert (self.output_size_per_partition % + self.quant_config.pack_factor == 0) + self.qweight = Parameter( + torch.empty( + self.input_size, + self.output_size_per_partition // + self.quant_config.pack_factor, + device="cuda", + 
dtype=torch.int32, + ), + requires_grad=False, + ) + self.qzeros = Parameter( + torch.empty( + self.input_size // self.quant_config.group_size, + self.output_size_per_partition // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + self.scales = Parameter( + torch.empty( + self.input_size // self.quant_config.group_size, + self.output_size_per_partition, + device="cuda", + dtype=dtype, + ), + requires_grad=False, + ) + + def apply_weights( + self, + x: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor) + reshaped_x = x.reshape(-1, x.shape[-1]) + out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, + self.qzeros, pack_factor) + if bias is not None: + out = out + bias + return out.reshape(out_shape) + + +class AWQRowParallelLinear(RowParallelLinear): + + def create_weights(self, dtype: torch.dtype) -> None: + assert (self.input_size_per_partition % + self.quant_config.weight_bits == 0) + assert self.output_size % self.quant_config.pack_factor == 0 + self.qweight = Parameter( + torch.empty( + self.input_size_per_partition, + self.output_size // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + self.qzeros = Parameter( + torch.empty( + self.input_size_per_partition // self.quant_config.group_size, + self.output_size // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + self.scales = Parameter( + torch.empty( + self.input_size_per_partition // self.quant_config.group_size, + self.output_size, + device="cuda", + dtype=dtype, + ), + requires_grad=False, + ) + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor) + reshaped_x = x.reshape(-1, x.shape[-1]) + out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, + self.qzeros, pack_factor) + return out.reshape(out_shape) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index cd6c6b67..30d1620d 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -8,7 +8,8 @@ from transformers import PretrainedConfig from vllm.config import ModelConfig from vllm.model_executor.models import * # pylint: disable=wildcard-import -from vllm.model_executor.weight_utils import initialize_dummy_weights +from vllm.model_executor.weight_utils import (get_quant_config, + initialize_dummy_weights) # TODO(woosuk): Lazy-load the model classes. _MODEL_REGISTRY = { @@ -30,6 +31,11 @@ _MODEL_REGISTRY = { "RWForCausalLM": FalconForCausalLM, } +# FIXME(woosuk): Remove this once all models support quantization. +_MODEL_CLASSES_SUPPORT_QUANTIZATION = [ + LlamaForCausalLM, +] + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): @@ -52,10 +58,30 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) + + # Get the quantization config. 
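Annotation: to make the tensor shapes of the AWQ layers above concrete: with 4-bit weights each int32 packs pack_factor = 32 // 4 = 8 values along the output dimension, and zeros/scales are stored once per group of group_size input channels. The numbers below assume pack_factor=8 and group_size=128, which is how the AWQ config referenced by this patch is typically parameterized (the quantization_utils hunks themselves are not shown in this excerpt):

    # Example: the gate_up_proj partition of a LLaMA-7B MLP on a single GPU.
    input_size = 4096                      # hidden_size
    output_size_per_partition = 2 * 11008  # gate_proj and up_proj fused
    pack_factor, group_size = 8, 128       # assumed AWQ config values

    qweight_shape = (input_size, output_size_per_partition // pack_factor)  # int32 (4096, 2752)
    qzeros_shape  = (input_size // group_size,
                     output_size_per_partition // pack_factor)              # int32 (32, 2752)
    scales_shape  = (input_size // group_size, output_size_per_partition)   # fp16  (32, 22016)
    print(qweight_shape, qzeros_shape, scales_shape)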
+ quant_config = None + if model_config.quantization is not None: + if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION: + raise ValueError( + f"Quantization is not supported for {model_class}.") + quant_config = get_quant_config(model_config.quantization, + model_config.model, + model_config.download_dir) + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - model = model_class(model_config.hf_config) + if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: + model = model_class(model_config.hf_config, quant_config) + else: + model = model_class(model_config.hf_config) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a2804d88..e87f0073 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, - hf_model_weights_iterator) +from vllm.model_executor.layers.quantized_linear import ParallelLinear from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.tensor_parallel import ( - VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) + VocabParallelEmbedding) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import ( + load_tensor_parallel_weights, load_padded_tensor_parallel_vocab, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -55,18 +57,21 @@ class LlamaMLP(nn.Module): hidden_size: int, intermediate_size: int, hidden_act: str, - ): + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() - self.gate_up_proj = ColumnParallelLinear(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - perform_initialization=False) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - perform_initialization=False) + self.gate_up_proj = ParallelLinear.column(hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=quant_config) + self.down_proj = ParallelLinear.row(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + perform_initialization=False, + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -87,7 +92,8 @@ class LlamaAttention(nn.Module): num_heads: int, num_kv_heads: int, rope_theta: float = 10000, - ): + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -103,20 +109,22 @@ class LlamaAttention(nn.Module): self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = ParallelLinear.column( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=False, gather_output=False, perform_initialization=False, + quant_config=quant_config, ) - self.o_proj = RowParallelLinear( + self.o_proj = ParallelLinear.row( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, perform_initialization=False, + quant_config=quant_config, ) self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, @@ -144,7 +152,11 @@ class LlamaAttention(nn.Module): class LlamaDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.hidden_size = config.hidden_size # Requires transformers > 4.32.0 @@ -154,11 +166,13 @@ class LlamaDecoderLayer(nn.Module): num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, + quant_config=quant_config, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -195,7 +209,11 @@ class LlamaDecoderLayer(nn.Module): class LlamaModel(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -205,7 +223,8 @@ class LlamaModel(nn.Module): self.embed_tokens = VocabParallelEmbedding( vocab_size, config.hidden_size, perform_initialization=False) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) + LlamaDecoderLayer(config, quant_config) + for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -237,16 +256,23 @@ class LlamaModel(nn.Module): class LlamaForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: super().__init__() self.config = config - self.model = LlamaModel(config) + self.quant_config = quant_config + self.model = LlamaModel(config, quant_config) vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - perform_initialization=False) + # NOTE: The LM head is not quantized. 
+ self.lm_head = ParallelLinear.column(config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + perform_initialization=False, + quant_config=None) self.sampler = Sampler(config.vocab_size) def forward( @@ -263,16 +289,28 @@ class LlamaForCausalLM(nn.Module): input_metadata) return next_tokens - _column_parallel_weights = [ - "qkv_proj.weight", "gate_proj.weight", "up_proj.weight" - ] - _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] + _column_parallel_layers = [] + _row_parallel_layers = ["o_proj", "down_proj"] def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): + if self.quant_config is None: + weight_suffixes = ["weight"] + else: + weight_suffixes = self.quant_config.get_tp_tensor_names() + + column_parallel_weights: List[str] = [] + for layer in self._column_parallel_layers: + for suffix in weight_suffixes: + column_parallel_weights.append(f"{layer}.{suffix}") + row_parallel_weights: List[str] = [] + for layer in self._row_parallel_layers: + for suffix in weight_suffixes: + row_parallel_weights.append(f"{layer}.{suffix}") + tp_size = get_tensor_model_parallel_world_size() tensor_model_parallel_rank = get_tensor_model_parallel_rank() q_proj_shard_size = (self.config.hidden_size // tp_size) @@ -293,11 +331,25 @@ class LlamaForCausalLM(nn.Module): if "rotary_emb.inv_freq" in name: continue + is_packed = False + is_transposed = False + if self.quant_config is not None: + is_packed = self.quant_config.is_packed(name) + is_transposed = self.quant_config.is_transposed(name) + if is_transposed: + loaded_weight = loaded_weight.T + is_attention_weight = False for weight_name, shard_size, offset in attention_weight_specs: if weight_name not in name: continue param = state_dict[name.replace(weight_name, "qkv_proj")] + if is_transposed: + param = param.T + + if is_packed: + shard_size //= self.quant_config.pack_factor + offset //= self.quant_config.pack_factor loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * @@ -316,6 +368,9 @@ class LlamaForCausalLM(nn.Module): if weight_name not in name: continue param = state_dict[name.replace(weight_name, "gate_up_proj")] + if is_transposed: + param = param.T + shard_size = param.shape[0] // 2 loaded_weight = loaded_weight[ shard_size * tensor_model_parallel_rank:shard_size * @@ -330,6 +385,8 @@ class LlamaForCausalLM(nn.Module): continue param = state_dict[name] + if is_transposed: + param = param.T if "embed_tokens" in name or "lm_head" in name: load_padded_tensor_parallel_vocab(param, loaded_weight, @@ -337,6 +394,6 @@ class LlamaForCausalLM(nn.Module): continue load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, + column_parallel_weights, + row_parallel_weights, tensor_model_parallel_rank) diff --git a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py b/vllm/model_executor/parallel_utils/tensor_parallel/layers.py index 0b4d32b6..0f144f77 100644 --- a/vllm/model_executor/parallel_utils/tensor_parallel/layers.py +++ b/vllm/model_executor/parallel_utils/tensor_parallel/layers.py @@ -4,7 +4,7 @@ # Parts of the code here are adapted from PyTorch # repo: https://github.com/pytorch/pytorch - +from typing import Optional import torch import torch.nn.functional as F @@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size, ) from .mappings import ( - 
copy_to_tensor_model_parallel_region, gather_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region, ) -from .random import get_cuda_rng_tracker from .utils import ( divide, VocabUtility, @@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): maybe_copy(attribute) -def _initialize_affine_weight_gpu(weight, init_method, - partition_dim, stride=1): - """Initialize affine weight for model parallel on GPU.""" - - set_tensor_model_parallel_attributes(tensor=weight, - is_parallel=True, - dim=partition_dim, - stride=stride) - - with get_cuda_rng_tracker().fork(): - init_method(weight) - - -def _initialize_affine_weight_cpu(weight, output_size, input_size, - per_partition_size, partition_dim, - init_method, stride=1, - return_master_weight=False, - *, params_dtype=None): - """Initialize affine weight for model parallel. - - Build the master weight on all processes and scatter - the relevant chunk.""" - - set_tensor_model_parallel_attributes(tensor=weight, - is_parallel=True, - dim=partition_dim, - stride=stride) - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Initialize master weight - master_weight = torch.empty(output_size, input_size, - dtype=torch.float, - requires_grad=False) - init_method(master_weight) - master_weight = master_weight.to(dtype=params_dtype) - - # Split and copy - per_partition_per_stride_size = divide(per_partition_size, stride) - weight_list = torch.split(master_weight, per_partition_per_stride_size, - dim=partition_dim) - rank = get_tensor_model_parallel_rank() - world_size = get_tensor_model_parallel_world_size() - my_weight_list = weight_list[rank::world_size] - - with torch.no_grad(): - torch.cat(my_weight_list, dim=partition_dim, out=weight) - if return_master_weight: - return master_weight - return None - - class VocabParallelEmbedding(torch.nn.Module): """Embedding parallelized in the vocabulary dimension. @@ -140,6 +85,9 @@ class VocabParallelEmbedding(torch.nn.Module): use_cpu_initialization: bool=False, perform_initialization: bool=True): super(VocabParallelEmbedding, self).__init__() + assert not perform_initialization + assert not use_cpu_initialization + # Keep the input dimensions. self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim @@ -162,24 +110,10 @@ class VocabParallelEmbedding(torch.nn.Module): self.num_embeddings_per_partition = self.vocab_end_index - \ self.vocab_start_index - # Allocate weights and initialize. - if use_cpu_initialization: - self.weight = Parameter(torch.empty( - self.num_embeddings_per_partition, self.embedding_dim, - dtype=params_dtype)) - if perform_initialization: - _initialize_affine_weight_cpu( - self.weight, self.num_embeddings, self.embedding_dim, - self.num_embeddings_per_partition, 0, init_method, - params_dtype=params_dtype) - else: - self.weight = Parameter(torch.empty( - self.num_embeddings_per_partition, self.embedding_dim, - device=torch.cuda.current_device(), dtype=params_dtype)) - if perform_initialization: - _initialize_affine_weight_gpu(self.weight, init_method, - partition_dim=0, stride=1) - + self.weight = Parameter(torch.empty( + self.num_embeddings_per_partition, self.embedding_dim, + device=torch.cuda.current_device(), dtype=params_dtype)) + def forward(self, input_): if self.tensor_model_parallel_size > 1: # Build the mask. 
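Annotation: with the Megatron-style initialization helpers removed, VocabParallelEmbedding keeps only the partitioned lookup: each rank owns the vocab range [vocab_start_index, vocab_end_index), ids outside that range are masked to a dummy local index, and their output rows are zeroed so the cross-rank reduction reconstructs the full embedding. A single-rank PyTorch sketch of that masking, with the distributed all-reduce left out:

    import torch
    import torch.nn.functional as F

    def vocab_parallel_lookup(input_ids, weight, vocab_start, vocab_end):
        """weight holds only this rank's rows, i.e. vocab ids [vocab_start, vocab_end)."""
        mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
        local_ids = input_ids - vocab_start
        local_ids[mask] = 0                      # any in-range dummy index works
        out = F.embedding(local_ids, weight)
        out[mask] = 0.0                          # foreign ids contribute nothing locally
        return out  # a real run then applies reduce_from_tensor_model_parallel_region(out)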
@@ -239,8 +173,11 @@ class ColumnParallelLinear(torch.nn.Module): params_dtype=None, use_cpu_initialization=False, perform_initialization=True, + quant_config=None, ): super(ColumnParallelLinear, self).__init__() + assert not perform_initialization + assert not use_cpu_initialization # Keep input parameters self.input_size = input_size @@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module): self.world_size = get_tensor_model_parallel_world_size() self.output_size_per_partition = divide(output_size, self.world_size) self.skip_bias_add = skip_bias_add + self.quant_config = quant_config if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module): # Parameters. # Note: torch.nn.functional.linear performs XA^T + b and as a result # we allocate the transpose. - # Initialize weight. - if use_cpu_initialization: - self.weight = Parameter(torch.empty(self.output_size_per_partition, - self.input_size, - dtype=params_dtype)) - if perform_initialization: - self.master_weight = _initialize_affine_weight_cpu( - self.weight, self.output_size, self.input_size, - self.output_size_per_partition, 0, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test) - else: - self.weight = Parameter(torch.empty( - self.output_size_per_partition, self.input_size, - device=torch.cuda.current_device(), dtype=params_dtype)) - if perform_initialization: - _initialize_affine_weight_gpu(self.weight, init_method, - partition_dim=0, stride=stride) + self.create_weights(params_dtype) if bias: - if use_cpu_initialization: - self.bias = Parameter(torch.empty( - self.output_size_per_partition, dtype=params_dtype)) - else: - self.bias = Parameter(torch.empty( - self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype)) + self.bias = Parameter(torch.empty( + self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype)) set_tensor_model_parallel_attributes(self.bias, True, 0, stride) # Always initialize bias to zero. with torch.no_grad(): @@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module): else: self.register_parameter('bias', None) + def create_weights(self, dtype: torch.dtype) -> None: + self.weight = Parameter(torch.empty( + self.output_size_per_partition, self.input_size, + device=torch.cuda.current_device(), dtype=dtype)) + + def apply_weights( + self, + x: torch.Tensor, + bias: Optional[torch.Tensor], + ) -> torch.Tensor: + return F.linear(x, self.weight, bias) def forward(self, input_): """Forward of ColumnParallelLinear @@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module): input_parallel = input_ # Matrix multiply. - output_parallel = F.linear(input_parallel, self.weight, bias) + output_parallel = self.apply_weights(input_parallel, bias) if self.gather_output: # All-gather across the partitions. 
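ColumnParallelLinear now routes weight allocation and the matmul through create_weights and apply_weights, so a quantized layer only has to override those two hooks. The sketch below illustrates that hook pattern with a made-up INT8 scale-only layer and deliberately avoids vLLM imports; it is not the AWQ layer this patch adds under vllm/model_executor/layers/quantized_linear/.

import torch
import torch.nn.functional as F

# Illustration of the create_weights/apply_weights hook pattern only;
# Int8Linear is a hypothetical example, not the patch's AWQ implementation.
class LinearBase(torch.nn.Module):

    def __init__(self, input_size: int, output_size: int):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.create_weights(torch.float32)

    def create_weights(self, dtype: torch.dtype) -> None:
        self.weight = torch.nn.Parameter(
            torch.empty(self.output_size, self.input_size, dtype=dtype))

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward() stays in the base class; subclasses change only the hooks.
        return self.apply_weights(x)


class Int8Linear(LinearBase):

    def create_weights(self, dtype: torch.dtype) -> None:
        self.register_buffer(
            "qweight",
            torch.zeros(self.output_size, self.input_size, dtype=torch.int8))
        self.register_buffer(
            "scales", torch.ones(self.output_size, dtype=torch.float32))

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        w = self.qweight.float() * self.scales[:, None]   # dequantize on the fly
        return F.linear(x, w)


layer = Int8Linear(input_size=8, output_size=4)
print(layer(torch.randn(2, 8)).shape)   # torch.Size([2, 4])
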
output = gather_from_tensor_model_parallel_region(output_parallel) @@ -361,8 +290,11 @@ class RowParallelLinear(torch.nn.Module): use_cpu_initialization=False, perform_initialization=True, reduce_results=True, + quant_config=None, ): super(RowParallelLinear, self).__init__() + assert not perform_initialization + assert not use_cpu_initialization # Keep input parameters self.input_size = input_size @@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module): self.world_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.world_size) self.skip_bias_add = skip_bias_add + self.quant_config = quant_config + + self.create_weights(params_dtype) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") - # Parameters. - # Note: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - # Initialize weight. - if use_cpu_initialization: - self.weight = Parameter(torch.empty(self.output_size, - self.input_size_per_partition, - dtype=params_dtype)) - if perform_initialization: - self.master_weight = _initialize_affine_weight_cpu( - self.weight, self.output_size, self.input_size, - self.input_size_per_partition, 1, init_method, - stride=stride, return_master_weight=keep_master_weight_for_test, - params_dtype=params_dtype) - else: - self.weight = Parameter(torch.empty( - self.output_size, self.input_size_per_partition, - device=torch.cuda.current_device(), dtype=params_dtype)) - if perform_initialization: - _initialize_affine_weight_gpu(self.weight, init_method, - partition_dim=1, stride=stride) if bias: - if use_cpu_initialization: - self.bias = Parameter(torch.empty(self.output_size, - dtype=params_dtype)) - else: - self.bias = Parameter(torch.empty( - self.output_size, device=torch.cuda.current_device(), - dtype=params_dtype)) + self.bias = Parameter(torch.empty( + self.output_size, device=torch.cuda.current_device(), + dtype=params_dtype)) # Always initialize bias to zero. with torch.no_grad(): self.bias.zero_() else: self.register_parameter('bias', None) - self.weight_t = self.weight.t() + + def create_weights(self, dtype: torch.dtype) -> None: + self.weight = Parameter(torch.empty( + self.output_size, self.input_size_per_partition, + device=torch.cuda.current_device(), dtype=dtype)) + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + return F.linear(x, self.weight) def forward(self, input_): """Forward of RowParallelLinear @@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module): else: input_parallel = scatter_to_tensor_model_parallel_region(input_) # Matrix multiply. 
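RowParallelLinear keeps the same two hooks and still all-reduces partial outputs when reduce_results is set, because each rank multiplies only its slice of the input dimension. A standalone numeric check of that identity, with illustrative sizes:

import torch

# Why the all-reduce is needed: partial matmuls over input-dim shards sum to
# the full result. Sizes here are made-up example values.
torch.manual_seed(0)
x = torch.randn(2, 8)                     # full activation
w = torch.randn(4, 8)                     # full weight (output_size, input_size)

tp_size = 2
x_shards = x.chunk(tp_size, dim=1)        # what scatter_to_...region would produce
w_shards = w.chunk(tp_size, dim=1)        # each rank's input_size_per_partition

partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
full = sum(partials)                      # what reduce_from_...region computes
print(torch.allclose(full, x @ w.t()))    # True
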
- output_parallel = F.linear(input_parallel, self.weight) + output_parallel = self.apply_weights(input_parallel) if self.reduce_results and self.world_size > 1: output_ = reduce_from_tensor_model_parallel_region(output_parallel) else: diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py new file mode 100644 index 00000000..df67758f --- /dev/null +++ b/vllm/model_executor/quantization_utils/__init__.py @@ -0,0 +1,20 @@ +from typing import Type + +from vllm.model_executor.quantization_utils.awq import AWQConfig +from vllm.model_executor.quantization_utils.base import QuantizationConfig + +_QUANTIZATION_REGISTRY = { + "awq": AWQConfig, +} + + +def get_quant_class(quantization: str) -> Type[QuantizationConfig]: + if quantization not in _QUANTIZATION_REGISTRY: + raise ValueError(f"Invalid quantization method: {quantization}") + return _QUANTIZATION_REGISTRY[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quant_class", +] diff --git a/vllm/model_executor/quantization_utils/awq.py b/vllm/model_executor/quantization_utils/awq.py new file mode 100644 index 00000000..ed8987e1 --- /dev/null +++ b/vllm/model_executor/quantization_utils/awq.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.quantization_utils.base import QuantizationConfig + + +class AWQConfig(QuantizationConfig): + """Config class for AWQ. + + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return (f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point})") + + @classmethod + def get_name(cls) -> str: + return "awq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + return cls(weight_bits, group_size, zero_point) + + @classmethod + def get_packed_tensor_names(cls) -> List[str]: + return ["qweight", "qzeros"] + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + return ["qweight", "qzeros", "scales"] + + @classmethod + def get_tp_tensor_names(cls) -> List[str]: + return ["qweight", "qzeros", "scales"] diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/quantization_utils/base.py new file mode 100644 index 00000000..cb406f4c --- /dev/null +++ b/vllm/model_executor/quantization_utils/base.py @@ -0,0 +1,65 @@ +from typing import Any, Dict, List + +import torch + + +class QuantizationConfig: + + @classmethod + def get_name(cls) -> str: + """Name of the quantization method.""" + raise NotImplementedError + + @classmethod + 
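A short usage sketch for the new registry and AWQConfig.from_config, assuming vLLM with this patch installed; the JSON values below are illustrative, not taken from a real checkpoint.

from vllm.model_executor.quantization_utils import get_quant_class

# Illustrative contents of a quant_config.json; real checkpoints may instead use
# the alternative key names ("bits", "group_size") that from_config also accepts.
hf_quant_config = {"w_bit": 4, "q_group_size": 128, "zero_point": True}

quant_cls = get_quant_class("awq")
quant_config = quant_cls.from_config(hf_quant_config)
print(quant_config)              # AWQConfig(weight_bits=4, group_size=128, zero_point=True)
print(quant_config.pack_factor)  # 8, i.e. eight INT4 weights per INT32
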
def get_supported_act_dtypes(cls) -> List[torch.dtype]: + """List of supported activation dtypes.""" + raise NotImplementedError + + @classmethod + def get_config_filenames(cls) -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @staticmethod + def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError(f"Cannot find any of {keys} in the model's " + "quantization config.") + + @classmethod + def get_packed_tensor_names(cls) -> List[str]: + raise NotImplementedError + + @classmethod + def is_packed(cls, tensor_name: str) -> bool: + """Returns True if a tensor is packed. + + A tensor is considered packed if each element in the tensor is a + packed representation of multiple elements in the original tensor. + For example, an INT32 element in the tensor may represent 8 INT4 + elements in the original tensor. + """ + return any(tag in tensor_name for tag in cls.get_packed_tensor_names()) + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + raise NotImplementedError + + @classmethod + def is_transposed(cls, tensor_name: str) -> bool: + """Returns True if a tensor is transposed relative to nn.Linear.weight. + """ + return any(tag in tensor_name + for tag in cls.get_transposed_tensor_names()) + + @classmethod + def get_tp_tensor_names(cls) -> List[str]: + raise NotImplementedError diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index c99f02bd..74de9684 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -4,7 +4,7 @@ import glob import json import os from collections import defaultdict -from typing import Iterator, List, Optional, Tuple, Any +from typing import Any, Iterator, List, Optional, Tuple from huggingface_hub import snapshot_download from safetensors.torch import load_file, save_file, safe_open @@ -13,6 +13,8 @@ import torch from tqdm.auto import tqdm from vllm.logger import init_logger +from vllm.model_executor.quantization_utils import get_quant_class +from vllm.model_executor.quantization_utils.base import QuantizationConfig logger = init_logger(__name__) @@ -44,7 +46,7 @@ def _shared_pointers(tensors): def convert_bin_to_safetensor_file( pt_filename: str, sf_filename: str, -): +) -> None: loaded = torch.load(pt_filename, map_location="cpu") if "state_dict" in loaded: loaded = loaded["state_dict"] @@ -78,16 +80,55 @@ def convert_bin_to_safetensor_file( raise RuntimeError(f"The output tensors do not match for key {k}") +# TODO(woosuk): Move this to other place. +def get_quant_config( + quantization: str, + model_name_or_path: str, + cache_dir: Optional[str] = None, +) -> QuantizationConfig: + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. 
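The name-based helpers above drive the packed/transposed handling in load_weights. A few example checks against AWQConfig, using illustrative parameter names in the usual AWQ checkpoint layout:

from vllm.model_executor.quantization_utils.awq import AWQConfig

# Example tensor names are illustrative; AWQ checkpoints store qweight/qzeros
# (packed) and scales (transposed, unpacked) per linear layer.
print(AWQConfig.is_packed("model.layers.0.self_attn.q_proj.qweight"))   # True
print(AWQConfig.is_transposed("model.layers.0.mlp.down_proj.scales"))   # True
print(AWQConfig.is_packed("model.layers.0.mlp.down_proj.scales"))       # False
print(AWQConfig.is_transposed("model.embed_tokens.weight"))             # False
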
+ with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download(model_name_or_path, + allow_patterns="*.json", + cache_dir=cache_dir, + tqdm_class=Disabledtqdm) + else: + hf_folder = model_name_or_path + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_cls = get_quant_class(quantization) + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in quant_cls.get_config_filenames()) + ] + if len(quant_config_files) == 0: + raise ValueError(f"Cannot find the config file for {quantization}") + if len(quant_config_files) > 1: + raise ValueError(f"Found multiple config files for {quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + return quant_cls.from_config(config) + + def prepare_hf_model_weights( model_name_or_path: str, cache_dir: Optional[str] = None, use_safetensors: bool = False, fall_back_to_pt: bool = True, revision: Optional[str] = None, -): +) -> Tuple[str, List[str], bool]: # Download model weights from huggingface. is_local = os.path.isdir(model_name_or_path) - allow_patterns = "*.safetensors" if use_safetensors else "*.bin" + if use_safetensors: + allow_patterns = ["*.safetensors"] + else: + # Some quantized models use .pt files for storing the weights. + allow_patterns = ["*.bin", "*.pt"] if not is_local: # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. @@ -99,7 +140,9 @@ def prepare_hf_model_weights( revision=revision) else: hf_folder = model_name_or_path - hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns)) + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) if not use_safetensors: hf_weights_files = [ x for x in hf_weights_files if not x.endswith("training_args.bin")
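get_quant_config above resolves the quantization settings from the checkpoint's JSON files. A hypothetical call site, assuming network access and the patched vLLM; the repository name is the AWQ example already referenced in AWQConfig.get_config_filenames:

from vllm.model_executor.weight_utils import get_quant_config

# Downloads the checkpoint's *.json files on first use; requires network access.
quant_config = get_quant_config(
    quantization="awq",
    model_name_or_path="casperhansen/vicuna-7b-v1.5-awq",
)
print(quant_config.get_name(), quant_config.weight_bits, quant_config.group_size)
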