Implement AWQ quantization support for LLaMA (#1032)

Co-authored-by: Robert Irvine <robert@seamlessml.com>
Co-authored-by: root <rirv938@gmail.com>
Co-authored-by: Casper <casperbh.96@gmail.com>
Co-authored-by: julian-q <julianhquevedo@gmail.com>
Woosuk Kwon authored on 2023-09-16 00:03:37 -07:00 (committed via GitHub)
parent b9fe4616f9
commit e3e79e9e8a
19 changed files with 1178 additions and 208 deletions
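
For orientation before the per-file diffs: the commit threads a new quantization option from the CLI benchmarks through EngineArgs and ModelConfig down to the LLaMA layers and a custom AWQ GEMM kernel. A minimal usage sketch follows (hedged: the checkpoint path is a placeholder for an AWQ-quantized LLaMA directory containing the packed qweight/qzeros/scales tensors; it is not part of this commit):

from vllm import LLM, SamplingParams

# "path/to/llama-awq" is a placeholder for an AWQ-quantized LLaMA checkpoint.
llm = LLM(model="path/to/llama-awq", quantization="awq")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)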

.gitignore

@@ -173,3 +173,7 @@ cython_debug/
# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp


@@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
    llm = LLM(
        model=args.model,
        tokenizer=args.tokenizer,
        quantization=args.quantization,
        tensor_parallel_size=args.tensor_parallel_size,
        max_num_seqs=args.batch_size,
        max_num_batched_tokens=args.batch_size * args.input_len,
@@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters',
                        type=int,
                        default=3,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    args = parser.parse_args()
    main(args)


@@ -3,7 +3,7 @@ import argparse
import json
import random
import time
from typing import List, Optional, Tuple

import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
@@ -22,15 +22,10 @@ def sample_requests(
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
@@ -63,6 +58,7 @@ def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
@@ -72,6 +68,7 @@ def run_vllm(
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
@@ -111,8 +108,8 @@ def run_hf(
    trust_remote_code: bool,
) -> float:
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
@@ -132,13 +129,14 @@ def run_hf(
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue

        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
@@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = get_tokenizer(args.tokenizer,
                              trust_remote_code=args.trust_remote_code)
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    if args.backend == "vllm":
        elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
@@ -215,6 +227,8 @@ if __name__ == "__main__":
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    if args.tokenizer is None:
        args.tokenizer = args.model

csrc/quantization.cpp (new file)

@@ -0,0 +1,15 @@
#include <torch/extension.h>
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
}


@@ -0,0 +1,79 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
}
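
As a cross-check on the inline PTX above, the following is a hedged scalar reference (in Python, not part of the commit) for what dequantize_s4_to_fp16x2 produces from one packed 32-bit word. The interleaved output order is inferred from the BOTTOM_MASK/TOP_MASK pairs and the >> 8 shift; values stay unsigned in [0, 15] because the zero point is subtracted later in the GEMM kernel.

def dequantize_s4_reference(word: int) -> list:
    # h[0] carries bits 0-3 and 16-19, h[1] bits 4-7 and 20-23,
    # h[2] bits 8-11 and 24-27, h[3] bits 12-15 and 28-31.
    shifts = [0, 16, 4, 20, 8, 24, 12, 28]
    return [float((word >> s) & 0xF) for s in shifts]

# 0x76543210 unpacks to [0.0, 4.0, 1.0, 5.0, 2.0, 6.0, 3.0, 7.0]
print(dequantize_s4_reference(0x76543210))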


@@ -0,0 +1,477 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
#include <cuda_fp16.h>
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (128 + 8)];
__shared__ half scaling_factors_shared[128];
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[32];
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 128;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 2
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ ((int)threadIdx.x) % (128 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (128)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdx_z indexes the split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if we formulate deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdx_z indexes the split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 72 (64+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if we formulate deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size");
if (num_out_channels % 128 == 0)
{
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
{
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
}
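
A hedged sketch of calling this entry point from Python with the layouts documented in the comments above (in_feats M x IC fp16, kernel IC x OC/8 int32, scaling_factors IC/G x OC fp16, zeros IC/G x OC/8 int32). It assumes the vllm.quantization_ops extension added in setup.py below has been built; the random integers are stand-ins for real packed weights, so only the shapes and dtypes are meaningful here.

import torch
from vllm import quantization_ops

M, IC, OC, G = 8, 4096, 4096, 128  # batch, in-channels, out-channels, group size
x = torch.randn(M, IC, dtype=torch.float16, device="cuda")
qweight = torch.randint(0, 2**31 - 1, (IC, OC // 8), dtype=torch.int32, device="cuda")
scales = torch.rand(IC // G, OC, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (IC // G, OC // 8), dtype=torch.int32, device="cuda")

# The AWQ linear layers below pass pack_factor (= 8) as split_k_iters.
out = quantization_ops.awq_gemm(x, qweight, scales, qzeros, 8)
print(out.shape)  # torch.Size([8, 4096])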


@@ -146,6 +146,20 @@ activation_extension = CUDAExtension(
)
ext_modules.append(activation_extension)

# Quantization kernels.
quantization_extension = CUDAExtension(
    name="vllm.quantization_ops",
    sources=[
        "csrc/quantization.cpp",
        "csrc/quantization/awq/gemm_kernels.cu",
    ],
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
    },
)
ext_modules.append(quantization_extension)


def get_path(*filepath) -> str:
    return os.path.join(ROOT_DIR, *filepath)


@@ -43,6 +43,8 @@ class ModelConfig:
            version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
    """

    def __init__(
@@ -57,6 +59,7 @@ class ModelConfig:
        seed: int,
        revision: Optional[str],
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
@@ -66,11 +69,13 @@ class ModelConfig:
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.quantization = quantization

        self.hf_config = get_config(model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self.max_model_len = None
        if max_model_len is not None:
            derived_max_model_len = self.get_max_model_len()
@@ -100,6 +105,17 @@ class ModelConfig:
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq"]
        if self.quantization is None:
            return
        quantization = self.quantization.lower()
        if quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization: {self.quantization}. Must be one of "
                f"{supported_quantization}.")
        self.quantization = quantization

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",


@@ -29,6 +29,7 @@ class EngineArgs:
    max_num_seqs: int = 256
    disable_log_stats: bool = False
    revision: Optional[str] = None
    quantization: Optional[str] = None

    def __post_init__(self):
        if self.tokenizer is None:
@@ -88,7 +89,6 @@ class EngineArgs:
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
@@ -150,6 +150,13 @@ class EngineArgs:
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', None],
                            default=None,
                            help='Method used to quantize the weights')
        return parser

    @classmethod
@@ -163,12 +170,11 @@ class EngineArgs:
    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
        model_config = ModelConfig(self.model, self.tokenizer,
                                   self.tokenizer_mode, self.trust_remote_code,
                                   self.download_dir, self.load_format,
                                   self.dtype, self.seed, self.revision,
                                   self.max_model_len, self.quantization)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space)


@@ -80,6 +80,7 @@ class LLMEngine:
            f"download_dir={model_config.download_dir!r}, "
            f"load_format={model_config.load_format}, "
            f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
            f"quantization={model_config.quantization}, "
            f"seed={model_config.seed})")
        # TODO(woosuk): Print more configs in debug mode.


@@ -0,0 +1,37 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
}
class ParallelLinear:
@classmethod
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return ColumnParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
return quant_linear_cls(*args, **kwargs)
@classmethod
def row(cls, *args, **kwargs) -> RowParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return RowParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
return quant_linear_cls(*args, **kwargs)
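
The registry keeps quantization opt-in per layer: a None quant_config falls through to the stock tensor-parallel linears, and a recognized method name swaps in its quantized counterparts. A self-contained toy version of the same dispatch pattern (illustrative classes only, not vLLM's real ones):

from typing import Optional

class DenseLinear:             # stand-in for ColumnParallelLinear / RowParallelLinear
    pass

class AWQLinear(DenseLinear):  # stand-in for the AWQ variants
    pass

REGISTRY = {"awq": AWQLinear}

def make_column_linear(quant_name: Optional[str] = None) -> DenseLinear:
    if quant_name is None:
        return DenseLinear()
    if quant_name not in REGISTRY:
        raise ValueError(f"No quantized linear is found for {quant_name}")
    return REGISTRY[quant_name]()

print(type(make_column_linear()).__name__)       # DenseLinear
print(type(make_column_linear("awq")).__name__)  # AWQLinear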


@@ -0,0 +1,102 @@
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
ColumnParallelLinear, RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
self.qweight = Parameter(
torch.empty(
self.input_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
return out.reshape(out_shape)
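
To make the shape bookkeeping above concrete, a hedged worked example of the packing arithmetic; it assumes the usual AWQ defaults of 4-bit weights and group size 128, which come from the quantization config rather than this file:

weight_bits = 4
pack_factor = 32 // weight_bits           # 8 four-bit values per int32
input_size, output_size, group_size = 4096, 11008, 128

qweight_shape = (input_size, output_size // pack_factor)               # (4096, 1376)
qzeros_shape = (input_size // group_size, output_size // pack_factor)  # (32, 1376)
scales_shape = (input_size // group_size, output_size)                 # (32, 11008)

# apply_weights recovers the logical output width the same way:
out_features = qweight_shape[1] * pack_factor                          # 11008
print(qweight_shape, qzeros_shape, scales_shape, out_features)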


@@ -8,7 +8,8 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.model_executor.models import *  # pylint: disable=wildcard-import
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
@@ -30,6 +31,11 @@ _MODEL_REGISTRY = {
    "RWForCausalLM": FalconForCausalLM,
}

# FIXME(woosuk): Remove this once all models support quantization.
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
    LlamaForCausalLM,
]


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
@@ -52,10 +58,30 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:

def get_model(model_config: ModelConfig) -> nn.Module:
    model_class = _get_model_architecture(model_config.hf_config)

    # Get the quantization config.
    quant_config = None
    if model_config.quantization is not None:
        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
            raise ValueError(
                f"Quantization is not supported for {model_class}.")
        quant_config = get_quant_config(model_config.quantization,
                                        model_config.model,
                                        model_config.download_dir)
        supported_dtypes = quant_config.get_supported_act_dtypes()
        if model_config.dtype not in supported_dtypes:
            raise ValueError(
                f"{model_config.dtype} is not supported for quantization "
                f"method {model_config.quantization}. Supported dtypes: "
                f"{supported_dtypes}")

    with _set_default_torch_dtype(model_config.dtype):
        # Create a model instance.
        # The weights will be initialized as empty tensors.
        if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
            model = model_class(model_config.hf_config, quant_config)
        else:
            model = model_class(model_config.hf_config)
        if model_config.load_format == "dummy":
            model = model.cuda()
            # NOTE(woosuk): For accurate performance evaluation, we assign


@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import ( from vllm.model_executor.layers.quantized_linear import ParallelLinear
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.model_executor.parallel_utils.parallel_state import ( from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import ( from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear) VocabParallelEmbedding)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
hidden_size: int, hidden_size: int,
intermediate_size: int, intermediate_size: int,
hidden_act: str, hidden_act: str,
): quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size, self.gate_up_proj = ParallelLinear.column(hidden_size,
2 * intermediate_size, 2 * intermediate_size,
bias=False, bias=False,
gather_output=False, gather_output=False,
perform_initialization=False) perform_initialization=False,
self.down_proj = RowParallelLinear(intermediate_size, quant_config=quant_config)
hidden_size, self.down_proj = ParallelLinear.row(intermediate_size,
bias=False, hidden_size,
input_is_parallel=True, bias=False,
perform_initialization=False) input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. " raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.") "Only silu is supported for now.")
@ -87,7 +92,8 @@ class LlamaAttention(nn.Module):
num_heads: int, num_heads: int,
num_kv_heads: int, num_kv_heads: int,
rope_theta: float = 10000, rope_theta: float = 10000,
): quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
@ -103,20 +109,22 @@ class LlamaAttention(nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta self.rope_theta = rope_theta
self.qkv_proj = ColumnParallelLinear( self.qkv_proj = ParallelLinear.column(
hidden_size, hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) * (self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim, self.head_dim,
bias=False, bias=False,
gather_output=False, gather_output=False,
perform_initialization=False, perform_initialization=False,
quant_config=quant_config,
) )
self.o_proj = RowParallelLinear( self.o_proj = ParallelLinear.row(
self.total_num_heads * self.head_dim, self.total_num_heads * self.head_dim,
hidden_size, hidden_size,
bias=False, bias=False,
input_is_parallel=True, input_is_parallel=True,
perform_initialization=False, perform_initialization=False,
quant_config=quant_config,
) )
self.attn = PagedAttentionWithRoPE(self.num_heads, self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim, self.head_dim,
@ -144,7 +152,11 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module): class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
@ -154,11 +166,13 @@ class LlamaDecoderLayer(nn.Module):
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
quant_config=quant_config,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
quant_config=quant_config,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps) eps=config.rms_norm_eps)
@ -195,7 +209,11 @@ class LlamaDecoderLayer(nn.Module):
class LlamaModel(nn.Module): class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
@ -205,7 +223,8 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding( self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False) vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers) LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
]) ])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -237,16 +256,23 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module): class LlamaForCausalLM(nn.Module):
def __init__(self, config): def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
self.model = LlamaModel(config) self.quant_config = quant_config
self.model = LlamaModel(config, quant_config)
vocab_size = ((config.vocab_size + 63) // 64) * 64 vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size, # NOTE: The LM head is not quantized.
vocab_size, self.lm_head = ParallelLinear.column(config.hidden_size,
bias=False, vocab_size,
gather_output=False, bias=False,
perform_initialization=False) gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
@ -263,16 +289,28 @@ class LlamaForCausalLM(nn.Module):
input_metadata) input_metadata)
return next_tokens return next_tokens
_column_parallel_weights = [ _column_parallel_layers = []
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight" _row_parallel_layers = ["o_proj", "down_proj"]
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
def load_weights(self, def load_weights(self,
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank() tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size) q_proj_shard_size = (self.config.hidden_size // tp_size)
@ -293,11 +331,25 @@ class LlamaForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = loaded_weight.T
is_attention_weight = False is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs: for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "qkv_proj")] param = state_dict[name.replace(weight_name, "qkv_proj")]
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size * shard_size * tensor_model_parallel_rank:shard_size *
@ -316,6 +368,9 @@ class LlamaForCausalLM(nn.Module):
if weight_name not in name: if weight_name not in name:
continue continue
param = state_dict[name.replace(weight_name, "gate_up_proj")] param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2 shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[ loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size * shard_size * tensor_model_parallel_rank:shard_size *
@ -330,6 +385,8 @@ class LlamaForCausalLM(nn.Module):
continue continue
param = state_dict[name] param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name: if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight, load_padded_tensor_parallel_vocab(param, loaded_weight,
@ -337,6 +394,6 @@ class LlamaForCausalLM(nn.Module):
continue continue
load_tensor_parallel_weights(param, loaded_weight, name, load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights, column_parallel_weights,
self._row_parallel_weights, row_parallel_weights,
tensor_model_parallel_rank) tensor_model_parallel_rank)
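As a standalone illustration of the bookkeeping added to load_weights above, the sketch below expands the per-layer weight suffixes the way the two loops do and applies the packed-shard arithmetic. The suffix list mirrors AWQConfig.get_tp_tensor_names() from this commit; the shard sizes are made-up numbers.
# Suffix expansion for tensor-parallel sharding (AWQ case).
row_parallel_layers = ["o_proj", "down_proj"]
awq_tp_suffixes = ["qweight", "qzeros", "scales"]   # AWQConfig.get_tp_tensor_names()
row_parallel_weights = [
    f"{layer}.{suffix}"
    for layer in row_parallel_layers
    for suffix in awq_tp_suffixes
]
print(row_parallel_weights)
# ['o_proj.qweight', 'o_proj.qzeros', 'o_proj.scales',
#  'down_proj.qweight', 'down_proj.qzeros', 'down_proj.scales']

# Packed tensors store 32 // 4 = 8 int4 values per int32 element, so the
# per-rank shard size and offset shrink by the pack factor.
pack_factor = 32 // 4                  # AWQConfig.pack_factor for 4-bit weights
shard_size, offset = 4096, 8192        # hypothetical unpacked column counts
shard_size //= pack_factor             # 512
offset //= pack_factor                 # 1024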


@ -4,7 +4,7 @@
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
) )
from .mappings import ( from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region, gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region,
) )
from .random import get_cuda_rng_tracker
from .utils import ( from .utils import (
divide, divide,
VocabUtility, VocabUtility,
@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
maybe_copy(attribute) maybe_copy(attribute)
def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork():
init_method(weight)
def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_size, partition_dim,
init_method, stride=1,
return_master_weight=False,
*, params_dtype=None):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=torch.float,
requires_grad=False)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class VocabParallelEmbedding(torch.nn.Module): class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension. """Embedding parallelized in the vocabulary dimension.
@ -140,6 +85,9 @@ class VocabParallelEmbedding(torch.nn.Module):
use_cpu_initialization: bool=False, use_cpu_initialization: bool=False,
perform_initialization: bool=True): perform_initialization: bool=True):
super(VocabParallelEmbedding, self).__init__() super(VocabParallelEmbedding, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep the input dimensions. # Keep the input dimensions.
self.num_embeddings = num_embeddings self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
@ -162,24 +110,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_per_partition = self.vocab_end_index - \ self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index self.vocab_start_index
# Allocate weights and initialize. self.weight = Parameter(torch.empty(
if use_cpu_initialization: self.num_embeddings_per_partition, self.embedding_dim,
self.weight = Parameter(torch.empty( device=torch.cuda.current_device(), dtype=params_dtype))
self.num_embeddings_per_partition, self.embedding_dim,
dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
def forward(self, input_): def forward(self, input_):
if self.tensor_model_parallel_size > 1: if self.tensor_model_parallel_size > 1:
# Build the mask. # Build the mask.
@ -239,8 +173,11 @@ class ColumnParallelLinear(torch.nn.Module):
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=True,
quant_config=None,
): ):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters # Keep input parameters
self.input_size = input_size self.input_size = input_size
@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.world_size) self.output_size_per_partition = divide(output_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
if params_dtype is None: if params_dtype is None:
params_dtype = torch.get_default_dtype() params_dtype = torch.get_default_dtype()
@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
# Parameters. # Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose. # we allocate the transpose.
# Initialize weight. self.create_weights(params_dtype)
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
if bias: if bias:
if use_cpu_initialization: self.bias = Parameter(torch.empty(
self.bias = Parameter(torch.empty( self.output_size_per_partition,
self.output_size_per_partition, dtype=params_dtype)) device=torch.cuda.current_device(),
else: dtype=params_dtype))
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride) set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return F.linear(x, self.weight, bias)
def forward(self, input_): def forward(self, input_):
"""Forward of ColumnParallelLinear """Forward of ColumnParallelLinear
@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_parallel = input_ input_parallel = input_
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias) output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
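With allocation and the matmul factored into create_weights and apply_weights, a quantized layer only needs to override those two hooks. Below is a standalone, non-parallel sketch of the pattern; the FakeQuantLinear class and its naive int8-plus-scales scheme are invented for illustration (the AWQ path in this commit instead carries packed int4 qweight/qzeros tensors plus scales, per the config class further down).
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter


class SimpleLinear(nn.Module):
    """Standalone mirror of the create_weights/apply_weights hooks (no TP, no bias)."""

    def __init__(self, input_size: int, output_size: int,
                 dtype: torch.dtype = torch.float32) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.create_weights(dtype)

    def create_weights(self, dtype: torch.dtype) -> None:
        self.weight = Parameter(
            torch.zeros(self.output_size, self.input_size, dtype=dtype))

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.apply_weights(x)


class FakeQuantLinear(SimpleLinear):
    """Invented example: int8 weights with per-output-channel scales."""

    def create_weights(self, dtype: torch.dtype) -> None:
        self.register_buffer(
            "qweight",
            torch.zeros(self.output_size, self.input_size, dtype=torch.int8))
        self.scales = Parameter(torch.ones(self.output_size, dtype=dtype))

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly; a real backend would dispatch a fused kernel.
        w = self.qweight.to(x.dtype) * self.scales[:, None]
        return F.linear(x, w)


x = torch.randn(2, 16)
print(FakeQuantLinear(16, 32)(x).shape)   # torch.Size([2, 32])
This split is what lets the column- and row-parallel classes above accept a quant_config without touching their forward logic.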
@ -361,8 +290,11 @@ class RowParallelLinear(torch.nn.Module):
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=True,
reduce_results=True, reduce_results=True,
quant_config=None,
): ):
super(RowParallelLinear, self).__init__() super(RowParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters # Keep input parameters
self.input_size = input_size self.input_size = input_size
@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size() self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size) self.input_size_per_partition = divide(input_size, self.world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results") "results can lead to incorrect results")
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if bias: if bias:
if use_cpu_initialization: self.bias = Parameter(torch.empty(
self.bias = Parameter(torch.empty(self.output_size, self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
self.bias.zero_() self.bias.zero_()
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.weight_t = self.weight.t()
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
def forward(self, input_): def forward(self, input_):
"""Forward of RowParallelLinear """Forward of RowParallelLinear
@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
else: else:
input_parallel = scatter_to_tensor_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight) output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.world_size > 1: if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel) output_ = reduce_from_tensor_model_parallel_region(output_parallel)
else: else:


@ -0,0 +1,20 @@
from typing import Type
from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
_QUANTIZATION_REGISTRY = {
"awq": AWQConfig,
}
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quant_class",
]
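Lookup through the registry is a plain dictionary access (assuming this commit is installed); unregistered names fail loudly:
from vllm.model_executor.quantization_utils import get_quant_class

quant_cls = get_quant_class("awq")   # -> AWQConfig, the only registered method
get_quant_class("gptq")              # raises ValueError: Invalid quantization method: gptq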


@ -0,0 +1,67 @@
from typing import Any, Dict, List
import torch
from vllm.model_executor.quantization_utils.base import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
@classmethod
def get_name(cls) -> str:
return "awq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_config_filenames(cls) -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros"]
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]


@ -0,0 +1,65 @@
from typing import Any, Dict, List
import torch
class QuantizationConfig:
@classmethod
def get_name(cls) -> str:
"""Name of the quantization method."""
raise NotImplementedError
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
def get_config_filenames(cls) -> List[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_packed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is packed.
A tensor is considered packed if each element in the tensor is a
packed representation of multiple elements in the original tensor.
For example, an INT32 element in the tensor may represent 8 INT4
elements in the original tensor.
"""
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_transposed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
"""
return any(tag in tensor_name
for tag in cls.get_transposed_tensor_names())
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
raise NotImplementedError
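The two name-based predicates drive the packed/transposed branches in load_weights; with the AWQ subclass they behave as below (the parameter names are hypothetical checkpoint keys):
from vllm.model_executor.quantization_utils.awq import AWQConfig

AWQConfig.is_packed("layers.0.self_attn.o_proj.qweight")       # True  (qweight, qzeros)
AWQConfig.is_packed("layers.0.self_attn.o_proj.scales")        # False
AWQConfig.is_transposed("layers.0.self_attn.o_proj.scales")    # True  (qweight, qzeros, scales)
AWQConfig.is_transposed("embed_tokens.weight")                 # False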


@ -4,7 +4,7 @@ import glob
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from typing import Iterator, List, Optional, Tuple, Any from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open from safetensors.torch import load_file, save_file, safe_open
@ -13,6 +13,8 @@ import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.base import QuantizationConfig
logger = init_logger(__name__) logger = init_logger(__name__)
@ -44,7 +46,7 @@ def _shared_pointers(tensors):
def convert_bin_to_safetensor_file( def convert_bin_to_safetensor_file(
pt_filename: str, pt_filename: str,
sf_filename: str, sf_filename: str,
): ) -> None:
loaded = torch.load(pt_filename, map_location="cpu") loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded: if "state_dict" in loaded:
loaded = loaded["state_dict"] loaded = loaded["state_dict"]
@ -78,16 +80,55 @@ def convert_bin_to_safetensor_file(
raise RuntimeError(f"The output tensors do not match for key {k}") raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to another place.
def get_quant_config(
quantization: str,
model_name_or_path: str,
cache_dir: Optional[str] = None,
) -> QuantizationConfig:
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.json",
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_cls = get_quant_class(quantization)
quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
]
if len(quant_config_files) == 0:
raise ValueError(f"Cannot find the config file for {quantization}")
if len(quant_config_files) > 1:
raise ValueError(f"Found multiple config files for {quantization}: "
f"{quant_config_files}")
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
config = json.load(f)
return quant_cls.from_config(config)
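A possible call site for the new helper, assuming this file is vllm/model_executor/weight_utils.py as in this commit and that the Hugging Face repo (taken from the comment in awq.py above) is reachable:
from vllm.model_executor.weight_utils import get_quant_config

# Downloads only the repo's *.json files, picks out the quant_config.json /
# quantize_config.json it finds, and parses it into an AWQConfig.
quant_config = get_quant_config("awq", "casperhansen/vicuna-7b-v1.5-awq")
print(quant_config.weight_bits, quant_config.group_size, quant_config.zero_point)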
def prepare_hf_model_weights( def prepare_hf_model_weights(
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
use_safetensors: bool = False, use_safetensors: bool = False,
fall_back_to_pt: bool = True, fall_back_to_pt: bool = True,
revision: Optional[str] = None, revision: Optional[str] = None,
): ) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface. # Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)
allow_patterns = "*.safetensors" if use_safetensors else "*.bin" if use_safetensors:
allow_patterns = ["*.safetensors"]
else:
# Some quantized models use .pt files for storing the weights.
allow_patterns = ["*.bin", "*.pt"]
if not is_local: if not is_local:
# Use file lock to prevent multiple processes from # Use file lock to prevent multiple processes from
# downloading the same model weights at the same time. # downloading the same model weights at the same time.
@ -99,7 +140,9 @@ def prepare_hf_model_weights(
revision=revision) revision=revision)
else: else:
hf_folder = model_name_or_path hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns)) hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if not use_safetensors: if not use_safetensors:
hf_weights_files = [ hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin") x for x in hf_weights_files if not x.endswith("training_args.bin")