Implement AWQ quantization support for LLaMA (#1032)

Co-authored-by: Robert Irvine <robert@seamlessml.com>
Co-authored-by: root <rirv938@gmail.com>
Co-authored-by: Casper <casperbh.96@gmail.com>
Co-authored-by: julian-q <julianhquevedo@gmail.com>
This commit is contained in:
Woosuk Kwon 2023-09-16 00:03:37 -07:00 committed by GitHub
parent b9fe4616f9
commit e3e79e9e8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1178 additions and 208 deletions

4
.gitignore vendored
View File

@ -173,3 +173,7 @@ cython_debug/
# Sphinx documentation
_build/
# vim swap files
*.swo
*.swp

View File

@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
llm = LLM(
model=args.model,
tokenizer=args.tokenizer,
quantization=args.quantization,
tensor_parallel_size=args.tensor_parallel_size,
max_num_seqs=args.batch_size,
max_num_batched_tokens=args.batch_size * args.input_len,
@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n', type=int, default=1,
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters', type=int, default=3,
parser.add_argument('--num-iters',
type=int,
default=3,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code', action='store_true',
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
args = parser.parse_args()
main(args)

View File

@ -3,7 +3,7 @@ import argparse
import json
import random
import time
from typing import List, Tuple
from typing import List, Optional, Tuple
import torch
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
@ -22,15 +22,10 @@ def sample_requests(
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [
data for data in dataset
if len(data["conversations"]) >= 2
]
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
@ -63,6 +58,7 @@ def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
@ -72,6 +68,7 @@ def run_vllm(
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
@ -111,8 +108,8 @@ def run_hf(
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
@ -132,13 +129,14 @@ def run_hf(
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) + max(
max_output_len, next_output_len)) <= 2048:
if (max(max_prompt_len, next_prompt_len) +
max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
random.seed(args.seed)
# Sample the requests.
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
tokenizer = get_tokenizer(args.tokenizer,
trust_remote_code=args.trust_remote_code)
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
args.quantization, args.tensor_parallel_size,
args.seed, args.n, args.use_beam_search,
args.trust_remote_code)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(
requests, args.model, tokenizer, args.n, args.use_beam_search,
args.hf_max_batch_size, args.trust_remote_code)
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(
prompt_len + output_len
for _, prompt_len, output_len in requests
)
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf"],
default="vllm")
parser.add_argument("--dataset", type=str, required=True,
parser.add_argument("--dataset",
type=str,
required=True,
help="Path to the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n", type=int, default=1,
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts", type=int, default=1000,
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size", type=int, default=None,
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
@ -215,6 +227,8 @@ if __name__ == "__main__":
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.tokenizer is None:
args.tokenizer = args.model

15
csrc/quantization.cpp Normal file
View File

@ -0,0 +1,15 @@
#include <torch/extension.h>
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"awq_gemm",
&awq_gemm,
"Quantized GEMM for AWQ");
}

View File

@ -0,0 +1,79 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#pragma once
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
{
uint4 result;
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
return result;
}

View File

@ -0,0 +1,477 @@
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
#include <cuda_fp16.h>
// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v1 << 16) | v0;
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (128 + 8)];
__shared__ half scaling_factors_shared[128];
__shared__ half zeros_shared[128];
int j_factors1 = ((OC + 128 - 1) / 128);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[32];
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 128;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 2
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ (((int)threadIdx.x) % (128 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
+ ((int)threadIdx.x) % (128 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (128)
+ (((int)threadIdx.x) % (128 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 128
+ ((int)threadIdx.y) * 64
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
{
static constexpr uint32_t ZERO = 0x0;
float C_warp[32];
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (64 + 8)];
__shared__ half scaling_factors_shared[64];
__shared__ half zeros_shared[64];
int j_factors1 = ((OC + 64 - 1) / 64);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
half A_shared_warp[8];
half B_shared_warp[16];
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
for (int i = 0; i < 8; ++i) {
C_warp[(j_0_4_init * 8) + i] = 0.0;
}
}
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / 64;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half* A_ptr = A
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
+ (((int)threadIdx.x) % (32 / 8)) * 8;
int* B_ptr = B
+ ((int)threadIdx.y) * (OC / 8) * 4
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ (((int)threadIdx.x) % (64 / 8)) * 1;
// Why * 1 in the above line?
half* A_shared_ptr = A_shared
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
half* B_shared_ptr = B_shared
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
int* zeros_ptr = zeros
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
+ ((int)threadIdx.x) % (64 / 8);
half* scaling_factors_ptr = scaling_factors
+ (((int)blockIdx_y) % j_factors1) * (64)
+ (((int)threadIdx.x) % (64 / 8)) * 8;
half* C_ptr = C
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
+ (((int)blockIdx_y) % j_factors1) * 64
+ ((int)threadIdx.y) * 32
+ (((int)threadIdx.x) % 4) * 2;
// preload s.f. and zeros
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
__syncthreads();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if (ld_A_flag)
{
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
}
else
{
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
}
__syncthreads();
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
: "r"(addr)
);
}
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
{
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
);
__asm__ __volatile__(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];\n"
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
: "r"(addr)
);
}
}
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
{
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
}
{
__asm__ __volatile__(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
}
}
}
}
// TODO: Shang: Hoist loop invariance.
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
for (int local_id = 0; local_id < 8; ++local_id) {
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
if (row_offset < M)
{
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
}
}
}
}
// in_feats: M, IC [float16]
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
// scaling_factors: IC // G, OC [float16]
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch::Tensor awq_gemm(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int split_k_iters)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
int num_out_feats = _out_feats.size(-2);
int num_out_channels = _out_feats.size(-1);
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
int group_size = num_in_channels / _scaling_factors.size(0);
if (num_out_channels % 64 != 0)
throw std::invalid_argument("OC is not multiple of cta_N = 64");
if (num_out_channels % 8 != 0)
throw std::invalid_argument("OC is not multiple of pack_num = 8");
if (group_size % 32 != 0)
throw std::invalid_argument("Group size should be a multiple of 32");
if (num_out_channels % group_size != 0)
throw std::invalid_argument("OC is not multiple of Group size");
if (num_out_channels % 128 == 0)
{
int j_factors1 = num_out_channels / 128 / 1;
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
else if (num_out_channels % 64 == 0)
{
int j_factors1 = num_out_channels / 64 / 1;
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3 threads_per_block(32, 2);
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
}
return _out_feats.sum(0);
}

View File

@ -146,6 +146,20 @@ activation_extension = CUDAExtension(
)
ext_modules.append(activation_extension)
# Quantization kernels.
quantization_extension = CUDAExtension(
name="vllm.quantization_ops",
sources=[
"csrc/quantization.cpp",
"csrc/quantization/awq/gemm_kernels.cu",
],
extra_compile_args={
"cxx": CXX_FLAGS,
"nvcc": NVCC_FLAGS,
},
)
ext_modules.append(quantization_extension)
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)

View File

@ -43,6 +43,8 @@ class ModelConfig:
version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
"""
def __init__(
@ -57,6 +59,7 @@ class ModelConfig:
seed: int,
revision: Optional[str],
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@ -66,11 +69,13 @@ class ModelConfig:
self.load_format = load_format
self.seed = seed
self.revision = revision
self.quantization = quantization
self.hf_config = get_config(model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_load_format()
self._verify_tokenizer_mode()
self._verify_quantization()
self.max_model_len = None
if max_model_len is not None:
derived_max_model_len = self.get_max_model_len()
@ -100,6 +105,17 @@ class ModelConfig:
"either 'auto' or 'slow'.")
self.tokenizer_mode = tokenizer_mode
def _verify_quantization(self) -> None:
supported_quantization = ["awq"]
if self.quantization is None:
return
quantization = self.quantization.lower()
if quantization not in supported_quantization:
raise ValueError(
f"Unknown quantization: {self.quantization}. Must be one of "
f"{supported_quantization}.")
self.quantization = quantization
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",

View File

@ -29,6 +29,7 @@ class EngineArgs:
max_num_seqs: int = 256
disable_log_stats: bool = False
revision: Optional[str] = None
quantization: Optional[str] = None
def __post_init__(self):
if self.tokenizer is None:
@ -88,7 +89,6 @@ class EngineArgs:
'a numpy cache to speed up the loading. '
'"dummy" will initialize the weights with random values, '
'which is mainly for profiling.')
# TODO(woosuk): Support FP32.
parser.add_argument(
'--dtype',
type=str,
@ -150,6 +150,13 @@ class EngineArgs:
parser.add_argument('--disable-log-stats',
action='store_true',
help='disable logging statistics')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=str,
choices=['awq', None],
default=None,
help='Method used to quantize the weights')
return parser
@classmethod
@ -163,12 +170,11 @@ class EngineArgs:
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
self.dtype, self.seed, self.revision,
self.max_model_len)
self.max_model_len, self.quantization)
cache_config = CacheConfig(self.block_size,
self.gpu_memory_utilization,
self.swap_space)

View File

@ -80,6 +80,7 @@ class LLMEngine:
f"download_dir={model_config.download_dir!r}, "
f"load_format={model_config.load_format}, "
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
f"quantization={model_config.quantization}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.

View File

@ -0,0 +1,37 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
}
class ParallelLinear:
@classmethod
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return ColumnParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
return quant_linear_cls(*args, **kwargs)
@classmethod
def row(cls, *args, **kwargs) -> RowParallelLinear:
quant_config = kwargs.get("quant_config", None)
if quant_config is None:
return RowParallelLinear(*args, **kwargs)
name = quant_config.get_name()
if name not in _QUANTIZED_LINEAR_REGISTRY:
raise ValueError(f"No quantized linear is found for {name}")
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
return quant_linear_cls(*args, **kwargs)

View File

@ -0,0 +1,102 @@
from typing import Optional
import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
ColumnParallelLinear, RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert self.input_size % self.quant_config.weight_bits == 0
assert (self.output_size_per_partition %
self.quant_config.pack_factor == 0)
self.qweight = Parameter(
torch.empty(
self.input_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size // self.quant_config.group_size,
self.output_size_per_partition,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
if bias is not None:
out = out + bias
return out.reshape(out_shape)
class AWQRowParallelLinear(RowParallelLinear):
def create_weights(self, dtype: torch.dtype) -> None:
assert (self.input_size_per_partition %
self.quant_config.weight_bits == 0)
assert self.output_size % self.quant_config.pack_factor == 0
self.qweight = Parameter(
torch.empty(
self.input_size_per_partition,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.qzeros = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
self.scales = Parameter(
torch.empty(
self.input_size_per_partition // self.quant_config.group_size,
self.output_size,
device="cuda",
dtype=dtype,
),
requires_grad=False,
)
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
reshaped_x = x.reshape(-1, x.shape[-1])
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
self.qzeros, pack_factor)
return out.reshape(out_shape)

View File

@ -8,7 +8,8 @@ from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.model_executor.models import * # pylint: disable=wildcard-import
from vllm.model_executor.weight_utils import initialize_dummy_weights
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
@ -30,6 +31,11 @@ _MODEL_REGISTRY = {
"RWForCausalLM": FalconForCausalLM,
}
# FIXME(woosuk): Remove this once all models support quantization.
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
LlamaForCausalLM,
]
@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):
@ -52,10 +58,30 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)
# Get the quantization config.
quant_config = None
if model_config.quantization is not None:
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
raise ValueError(
f"Quantization is not supported for {model_class}.")
quant_config = get_quant_config(model_config.quantization,
model_config.model,
model_config.download_dir)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
f"{model_config.dtype} is not supported for quantization "
f"method {model_config.quantization}. Supported dtypes: "
f"{supported_dtypes}")
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
model = model_class(model_config.hf_config)
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
model = model_class(model_config.hf_config, quant_config)
else:
model = model_class(model_config.hf_config)
if model_config.load_format == "dummy":
model = model.cuda()
# NOTE(woosuk): For accurate performance evaluation, we assign

View File

@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
VocabParallelEmbedding)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
hidden_act: str,
):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.gate_up_proj = ParallelLinear.column(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config)
self.down_proj = ParallelLinear.row(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@ -87,7 +92,8 @@ class LlamaAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
):
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@ -103,20 +109,22 @@ class LlamaAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = ColumnParallelLinear(
self.qkv_proj = ParallelLinear.column(
hidden_size,
(self.total_num_heads + 2 * self.total_num_kv_heads) *
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.o_proj = ParallelLinear.row(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
@ -144,7 +152,11 @@ class LlamaAttention(nn.Module):
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
@ -154,11 +166,13 @@ class LlamaDecoderLayer(nn.Module):
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
quant_config=quant_config,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@ -195,7 +209,11 @@ class LlamaDecoderLayer(nn.Module):
class LlamaModel(nn.Module):
def __init__(self, config: LlamaConfig):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
@ -205,7 +223,8 @@ class LlamaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@ -237,16 +256,23 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):
def __init__(self, config):
def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.model = LlamaModel(config)
self.quant_config = quant_config
self.model = LlamaModel(config, quant_config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
# NOTE: The LM head is not quantized.
self.lm_head = ParallelLinear.column(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size)
def forward(
@ -263,16 +289,28 @@ class LlamaForCausalLM(nn.Module):
input_metadata)
return next_tokens
_column_parallel_weights = [
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
]
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
_column_parallel_layers = []
_row_parallel_layers = ["o_proj", "down_proj"]
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
if self.quant_config is None:
weight_suffixes = ["weight"]
else:
weight_suffixes = self.quant_config.get_tp_tensor_names()
column_parallel_weights: List[str] = []
for layer in self._column_parallel_layers:
for suffix in weight_suffixes:
column_parallel_weights.append(f"{layer}.{suffix}")
row_parallel_weights: List[str] = []
for layer in self._row_parallel_layers:
for suffix in weight_suffixes:
row_parallel_weights.append(f"{layer}.{suffix}")
tp_size = get_tensor_model_parallel_world_size()
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
q_proj_shard_size = (self.config.hidden_size // tp_size)
@ -293,11 +331,25 @@ class LlamaForCausalLM(nn.Module):
if "rotary_emb.inv_freq" in name:
continue
is_packed = False
is_transposed = False
if self.quant_config is not None:
is_packed = self.quant_config.is_packed(name)
is_transposed = self.quant_config.is_transposed(name)
if is_transposed:
loaded_weight = loaded_weight.T
is_attention_weight = False
for weight_name, shard_size, offset in attention_weight_specs:
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "qkv_proj")]
if is_transposed:
param = param.T
if is_packed:
shard_size //= self.quant_config.pack_factor
offset //= self.quant_config.pack_factor
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
@ -316,6 +368,9 @@ class LlamaForCausalLM(nn.Module):
if weight_name not in name:
continue
param = state_dict[name.replace(weight_name, "gate_up_proj")]
if is_transposed:
param = param.T
shard_size = param.shape[0] // 2
loaded_weight = loaded_weight[
shard_size * tensor_model_parallel_rank:shard_size *
@ -330,6 +385,8 @@ class LlamaForCausalLM(nn.Module):
continue
param = state_dict[name]
if is_transposed:
param = param.T
if "embed_tokens" in name or "lm_head" in name:
load_padded_tensor_parallel_vocab(param, loaded_weight,
@ -337,6 +394,6 @@ class LlamaForCausalLM(nn.Module):
continue
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
column_parallel_weights,
row_parallel_weights,
tensor_model_parallel_rank)

View File

@ -4,7 +4,7 @@
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
from typing import Optional
import torch
import torch.nn.functional as F
@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .random import get_cuda_rng_tracker
from .utils import (
divide,
VocabUtility,
@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
maybe_copy(attribute)
def _initialize_affine_weight_gpu(weight, init_method,
partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
with get_cuda_rng_tracker().fork():
init_method(weight)
def _initialize_affine_weight_cpu(weight, output_size, input_size,
per_partition_size, partition_dim,
init_method, stride=1,
return_master_weight=False,
*, params_dtype=None):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
set_tensor_model_parallel_attributes(tensor=weight,
is_parallel=True,
dim=partition_dim,
stride=stride)
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Initialize master weight
master_weight = torch.empty(output_size, input_size,
dtype=torch.float,
requires_grad=False)
init_method(master_weight)
master_weight = master_weight.to(dtype=params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(master_weight, per_partition_per_stride_size,
dim=partition_dim)
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
@ -140,6 +85,9 @@ class VocabParallelEmbedding(torch.nn.Module):
use_cpu_initialization: bool=False,
perform_initialization: bool=True):
super(VocabParallelEmbedding, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
@ -162,24 +110,10 @@ class VocabParallelEmbedding(torch.nn.Module):
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
# Allocate weights and initialize.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_cpu(
self.weight, self.num_embeddings, self.embedding_dim,
self.num_embeddings_per_partition, 0, init_method,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=1)
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype))
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
@ -239,8 +173,11 @@ class ColumnParallelLinear(torch.nn.Module):
params_dtype=None,
use_cpu_initialization=False,
perform_initialization=True,
quant_config=None,
):
super(ColumnParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters
self.input_size = input_size
@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.world_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
if params_dtype is None:
params_dtype = torch.get_default_dtype()
@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size_per_partition,
self.input_size,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.output_size_per_partition, 0, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test)
else:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=0, stride=stride)
self.create_weights(params_dtype)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(
self.output_size_per_partition, dtype=params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
with torch.no_grad():
@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return F.linear(x, self.weight, bias)
def forward(self, input_):
"""Forward of ColumnParallelLinear
@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
input_parallel = input_
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight, bias)
output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
@ -361,8 +290,11 @@ class RowParallelLinear(torch.nn.Module):
use_cpu_initialization=False,
perform_initialization=True,
reduce_results=True,
quant_config=None,
):
super(RowParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters
self.input_size = input_size
@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if use_cpu_initialization:
self.weight = Parameter(torch.empty(self.output_size,
self.input_size_per_partition,
dtype=params_dtype))
if perform_initialization:
self.master_weight = _initialize_affine_weight_cpu(
self.weight, self.output_size, self.input_size,
self.input_size_per_partition, 1, init_method,
stride=stride, return_master_weight=keep_master_weight_for_test,
params_dtype=params_dtype)
else:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=params_dtype))
if perform_initialization:
_initialize_affine_weight_gpu(self.weight, init_method,
partition_dim=1, stride=stride)
if bias:
if use_cpu_initialization:
self.bias = Parameter(torch.empty(self.output_size,
dtype=params_dtype))
else:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.weight_t = self.weight.t()
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
def forward(self, input_):
"""Forward of RowParallelLinear
@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
else:

View File

@ -0,0 +1,20 @@
from typing import Type
from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig
_QUANTIZATION_REGISTRY = {
"awq": AWQConfig,
}
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
if quantization not in _QUANTIZATION_REGISTRY:
raise ValueError(f"Invalid quantization method: {quantization}")
return _QUANTIZATION_REGISTRY[quantization]
__all__ = [
"QuantizationConfig",
"get_quant_class",
]

View File

@ -0,0 +1,67 @@
from typing import Any, Dict, List
import torch
from vllm.model_executor.quantization_utils.base import QuantizationConfig
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
@classmethod
def get_name(cls) -> str:
return "awq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_config_filenames(cls) -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros"]
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
return ["qweight", "qzeros", "scales"]

View File

@ -0,0 +1,65 @@
from typing import Any, Dict, List
import torch
class QuantizationConfig:
@classmethod
def get_name(cls) -> str:
"""Name of the quantization method."""
raise NotImplementedError
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@classmethod
def get_config_filenames(cls) -> List[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@classmethod
def get_packed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_packed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is packed.
A tensor is considered packed if each element in the tensor is a
packed representation of multiple elements in the original tensor.
For example, an INT32 element in the tensor may represent 8 INT4
elements in the original tensor.
"""
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
@classmethod
def get_transposed_tensor_names(cls) -> List[str]:
raise NotImplementedError
@classmethod
def is_transposed(cls, tensor_name: str) -> bool:
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
"""
return any(tag in tensor_name
for tag in cls.get_transposed_tensor_names())
@classmethod
def get_tp_tensor_names(cls) -> List[str]:
raise NotImplementedError

View File

@ -4,7 +4,7 @@ import glob
import json
import os
from collections import defaultdict
from typing import Iterator, List, Optional, Tuple, Any
from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file, safe_open
@ -13,6 +13,8 @@ import torch
from tqdm.auto import tqdm
from vllm.logger import init_logger
from vllm.model_executor.quantization_utils import get_quant_class
from vllm.model_executor.quantization_utils.base import QuantizationConfig
logger = init_logger(__name__)
@ -44,7 +46,7 @@ def _shared_pointers(tensors):
def convert_bin_to_safetensor_file(
pt_filename: str,
sf_filename: str,
):
) -> None:
loaded = torch.load(pt_filename, map_location="cpu")
if "state_dict" in loaded:
loaded = loaded["state_dict"]
@ -78,16 +80,55 @@ def convert_bin_to_safetensor_file(
raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(
quantization: str,
model_name_or_path: str,
cache_dir: Optional[str] = None,
) -> QuantizationConfig:
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns="*.json",
cache_dir=cache_dir,
tqdm_class=Disabledtqdm)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_cls = get_quant_class(quantization)
quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
]
if len(quant_config_files) == 0:
raise ValueError(f"Cannot find the config file for {quantization}")
if len(quant_config_files) > 1:
raise ValueError(f"Found multiple config files for {quantization}: "
f"{quant_config_files}")
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
config = json.load(f)
return quant_cls.from_config(config)
def prepare_hf_model_weights(
model_name_or_path: str,
cache_dir: Optional[str] = None,
use_safetensors: bool = False,
fall_back_to_pt: bool = True,
revision: Optional[str] = None,
):
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path)
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
if use_safetensors:
allow_patterns = ["*.safetensors"]
else:
# Some quantized models use .pt files for storing the weights.
allow_patterns = ["*.bin", "*.pt"]
if not is_local:
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
@ -99,7 +140,9 @@ def prepare_hf_model_weights(
revision=revision)
else:
hf_folder = model_name_or_path
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
hf_weights_files: List[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if not use_safetensors:
hf_weights_files = [
x for x in hf_weights_files if not x.endswith("training_args.bin")