Implement AWQ quantization support for LLaMA (#1032)
Co-authored-by: Robert Irvine <robert@seamlessml.com> Co-authored-by: root <rirv938@gmail.com> Co-authored-by: Casper <casperbh.96@gmail.com> Co-authored-by: julian-q <julianhquevedo@gmail.com>
This commit is contained in:
parent
b9fe4616f9
commit
e3e79e9e8a
4
.gitignore
vendored
4
.gitignore
vendored
@ -173,3 +173,7 @@ cython_debug/
|
||||
|
||||
# Sphinx documentation
|
||||
_build/
|
||||
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
|
||||
llm = LLM(
|
||||
model=args.model,
|
||||
tokenizer=args.tokenizer,
|
||||
quantization=args.quantization,
|
||||
tensor_parallel_size=args.tensor_parallel_size,
|
||||
max_num_seqs=args.batch_size,
|
||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||
@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Benchmark the latency of processing a single batch of '
|
||||
'requests till completion.')
|
||||
'requests till completion.')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', None],
|
||||
default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
parser.add_argument('--output-len', type=int, default=128)
|
||||
parser.add_argument('--batch-size', type=int, default=8)
|
||||
parser.add_argument('--n', type=int, default=1,
|
||||
parser.add_argument('--n',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of generated sequences per prompt.')
|
||||
parser.add_argument('--use-beam-search', action='store_true')
|
||||
parser.add_argument('--num-iters', type=int, default=3,
|
||||
parser.add_argument('--num-iters',
|
||||
type=int,
|
||||
default=3,
|
||||
help='Number of iterations to run.')
|
||||
parser.add_argument('--trust-remote-code', action='store_true',
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
help='trust remote code from huggingface')
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -3,7 +3,7 @@ import argparse
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||
@ -22,15 +22,10 @@ def sample_requests(
|
||||
with open(dataset_path) as f:
|
||||
dataset = json.load(f)
|
||||
# Filter out the conversations with less than 2 turns.
|
||||
dataset = [
|
||||
data for data in dataset
|
||||
if len(data["conversations"]) >= 2
|
||||
]
|
||||
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||
# Only keep the first two turns of each conversation.
|
||||
dataset = [
|
||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||
for data in dataset
|
||||
]
|
||||
dataset = [(data["conversations"][0]["value"],
|
||||
data["conversations"][1]["value"]) for data in dataset]
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompts = [prompt for prompt, _ in dataset]
|
||||
@ -63,6 +58,7 @@ def run_vllm(
|
||||
requests: List[Tuple[str, int, int]],
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
quantization: Optional[str],
|
||||
tensor_parallel_size: int,
|
||||
seed: int,
|
||||
n: int,
|
||||
@ -72,6 +68,7 @@ def run_vllm(
|
||||
llm = LLM(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
quantization=quantization,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
seed=seed,
|
||||
trust_remote_code=trust_remote_code,
|
||||
@ -111,8 +108,8 @@ def run_hf(
|
||||
trust_remote_code: bool,
|
||||
) -> float:
|
||||
assert not use_beam_search
|
||||
llm = AutoModelForCausalLM.from_pretrained(model,
|
||||
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
llm = AutoModelForCausalLM.from_pretrained(
|
||||
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||
if llm.config.model_type == "llama":
|
||||
# To enable padding in the HF backend.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
@ -132,13 +129,14 @@ def run_hf(
|
||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||
# Check if we can add more requests to the batch.
|
||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||
if (max(max_prompt_len, next_prompt_len) + max(
|
||||
max_output_len, next_output_len)) <= 2048:
|
||||
if (max(max_prompt_len, next_prompt_len) +
|
||||
max(max_output_len, next_output_len)) <= 2048:
|
||||
# We can add more requests to the batch.
|
||||
continue
|
||||
|
||||
# Generate the sequences.
|
||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
||||
input_ids = tokenizer(batch, return_tensors="pt",
|
||||
padding=True).input_ids
|
||||
llm_outputs = llm.generate(
|
||||
input_ids=input_ids.cuda(),
|
||||
do_sample=not use_beam_search,
|
||||
@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
|
||||
# Sample the requests.
|
||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
||||
tokenizer = get_tokenizer(args.tokenizer,
|
||||
trust_remote_code=args.trust_remote_code)
|
||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||
|
||||
if args.backend == "vllm":
|
||||
elapsed_time = run_vllm(
|
||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
|
||||
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||
args.quantization, args.tensor_parallel_size,
|
||||
args.seed, args.n, args.use_beam_search,
|
||||
args.trust_remote_code)
|
||||
elif args.backend == "hf":
|
||||
assert args.tensor_parallel_size == 1
|
||||
elapsed_time = run_hf(
|
||||
requests, args.model, tokenizer, args.n, args.use_beam_search,
|
||||
args.hf_max_batch_size, args.trust_remote_code)
|
||||
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||
args.use_beam_search, args.hf_max_batch_size,
|
||||
args.trust_remote_code)
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {args.backend}")
|
||||
total_num_tokens = sum(
|
||||
prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests
|
||||
)
|
||||
total_num_tokens = sum(prompt_len + output_len
|
||||
for _, prompt_len, output_len in requests)
|
||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
|
||||
parser.add_argument("--backend",
|
||||
type=str,
|
||||
choices=["vllm", "hf"],
|
||||
default="vllm")
|
||||
parser.add_argument("--dataset", type=str, required=True,
|
||||
parser.add_argument("--dataset",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the dataset.")
|
||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', None],
|
||||
default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n", type=int, default=1,
|
||||
parser.add_argument("--n",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of generated sequences per prompt.")
|
||||
parser.add_argument("--use-beam-search", action="store_true")
|
||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
||||
parser.add_argument("--num-prompts",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of prompts to process.")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
||||
parser.add_argument("--hf-max-batch-size",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum batch size for HF backend.")
|
||||
parser.add_argument('--trust-remote-code',
|
||||
action='store_true',
|
||||
@ -215,6 +227,8 @@ if __name__ == "__main__":
|
||||
elif args.backend == "hf":
|
||||
if args.hf_max_batch_size is None:
|
||||
raise ValueError("HF max batch size is required for HF backend.")
|
||||
if args.quantization is not None:
|
||||
raise ValueError("Quantization is only for vLLM backend.")
|
||||
if args.tokenizer is None:
|
||||
args.tokenizer = args.model
|
||||
|
||||
|
15
csrc/quantization.cpp
Normal file
15
csrc/quantization.cpp
Normal file
@ -0,0 +1,15 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"awq_gemm",
|
||||
&awq_gemm,
|
||||
"Quantized GEMM for AWQ");
|
||||
}
|
79
csrc/quantization/awq/dequantize.cuh
Normal file
79
csrc/quantization/awq/dequantize.cuh
Normal file
@ -0,0 +1,79 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||
{
|
||||
uint4 result;
|
||||
|
||||
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||
|
||||
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||
|
||||
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||
// immediately before required.
|
||||
const uint32_t top_i4s = i4s >> 8;
|
||||
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[0])
|
||||
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[1])
|
||||
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[2])
|
||||
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||
: "=r"(h[3])
|
||||
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||
|
||||
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||
|
||||
// This is the half2 {1032, 1032} represented as an integer.
|
||||
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||
// This is the half2 {-72, -72} represented as an integer.
|
||||
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||
// Haotian: Let's use {-64, -64}.
|
||||
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||
|
||||
// Finally, we construct the output numbers.
|
||||
// Convert elt_01
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_23
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
// Convert elt_45
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||
// Convert elt_67
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
477
csrc/quantization/awq/gemm_kernels.cu
Normal file
477
csrc/quantization/awq/gemm_kernels.cu
Normal file
@ -0,0 +1,477 @@
|
||||
/*
|
||||
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||
@article{lin2023awq,
|
||||
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||
journal={arXiv},
|
||||
year={2023}
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "dequantize.cuh"
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
// Pack two half values.
|
||||
static inline __device__ __host__ unsigned
|
||||
__pack_half2(const half x, const half y) {
|
||||
unsigned v0 = *((unsigned short *)&x);
|
||||
unsigned v1 = *((unsigned short *)&y);
|
||||
return (v1 << 16) | v0;
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
{
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (128 + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[128];
|
||||
__shared__ half zeros_shared[128];
|
||||
|
||||
int j_factors1 = ((OC + 128 - 1) / 128);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[32];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 2
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
||||
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||
+ ((int)threadIdx.x) % (128 / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (128)
|
||||
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||
+ ((int)threadIdx.y) * 64
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
|
||||
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||
{
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
float C_warp[32];
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (64 + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[64];
|
||||
__shared__ half zeros_shared[64];
|
||||
|
||||
int j_factors1 = ((OC + 64 - 1) / 64);
|
||||
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
half A_shared_warp[8];
|
||||
half B_shared_warp[16];
|
||||
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / 64;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||
|
||||
half* A_ptr = A
|
||||
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||
|
||||
int* B_ptr = B
|
||||
+ ((int)threadIdx.y) * (OC / 8) * 4
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
||||
// Why * 1 in the above line?
|
||||
|
||||
half* A_shared_ptr = A_shared
|
||||
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||
|
||||
half* B_shared_ptr = B_shared
|
||||
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
||||
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
int* zeros_ptr = zeros
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||
+ ((int)threadIdx.x) % (64 / 8);
|
||||
|
||||
half* scaling_factors_ptr = scaling_factors
|
||||
+ (((int)blockIdx_y) % j_factors1) * (64)
|
||||
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||
|
||||
half* C_ptr = C
|
||||
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||
+ ((int)threadIdx.y) * 32
|
||||
+ (((int)threadIdx.x) % 4) * 2;
|
||||
|
||||
// preload s.f. and zeros
|
||||
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||
__syncthreads();
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
if (ld_A_flag)
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||
}
|
||||
else
|
||||
{
|
||||
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||
/*
|
||||
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||
}
|
||||
*/
|
||||
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||
|
||||
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
||||
|
||||
// B: 32 x 136 (128+8) float16
|
||||
// each warp: 32 x 4
|
||||
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||
/*
|
||||
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||
}
|
||||
*/
|
||||
|
||||
// write back
|
||||
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
||||
{
|
||||
{
|
||||
unsigned int addr;
|
||||
__asm__ __volatile__(
|
||||
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||
: "=r"(addr)
|
||||
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||
);
|
||||
__asm__ __volatile__(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];\n"
|
||||
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||
: "r"(addr)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||
{
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||
}
|
||||
|
||||
{
|
||||
__asm__ __volatile__(
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Shang: Hoist loop invariance.
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
||||
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||
if (row_offset < M)
|
||||
{
|
||||
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// in_feats: M, IC [float16]
|
||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||
// scaling_factors: IC // G, OC [float16]
|
||||
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||
// assume that batch_size < 16 for now
|
||||
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters)
|
||||
{
|
||||
int num_in_feats = _in_feats.size(0);
|
||||
int num_in_channels = _in_feats.size(1);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||
|
||||
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||
int num_out_feats = _out_feats.size(-2);
|
||||
int num_out_channels = _out_feats.size(-1);
|
||||
|
||||
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||
|
||||
if (num_out_channels % 64 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||
if (num_out_channels % 8 != 0)
|
||||
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||
if (group_size % 32 != 0)
|
||||
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||
if (num_out_channels % group_size != 0)
|
||||
throw std::invalid_argument("OC is not multiple of Group size");
|
||||
|
||||
if (num_out_channels % 128 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 128 / 1;
|
||||
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
}
|
||||
else if (num_out_channels % 64 == 0)
|
||||
{
|
||||
int j_factors1 = num_out_channels / 64 / 1;
|
||||
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||
|
||||
// threadIdx.x: 32
|
||||
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||
dim3 threads_per_block(32, 2);
|
||||
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
|
||||
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||
}
|
||||
return _out_feats.sum(0);
|
||||
}
|
14
setup.py
14
setup.py
@ -146,6 +146,20 @@ activation_extension = CUDAExtension(
|
||||
)
|
||||
ext_modules.append(activation_extension)
|
||||
|
||||
# Quantization kernels.
|
||||
quantization_extension = CUDAExtension(
|
||||
name="vllm.quantization_ops",
|
||||
sources=[
|
||||
"csrc/quantization.cpp",
|
||||
"csrc/quantization/awq/gemm_kernels.cu",
|
||||
],
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
},
|
||||
)
|
||||
ext_modules.append(quantization_extension)
|
||||
|
||||
|
||||
def get_path(*filepath) -> str:
|
||||
return os.path.join(ROOT_DIR, *filepath)
|
||||
|
@ -43,6 +43,8 @@ class ModelConfig:
|
||||
version.
|
||||
max_model_len: Maximum length of a sequence (including prompt and
|
||||
output). If None, will be derived from the model.
|
||||
quantization: Quantization method that was used to quantize the model
|
||||
weights. If None, we assume the model weights are not quantized.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -57,6 +59,7 @@ class ModelConfig:
|
||||
seed: int,
|
||||
revision: Optional[str],
|
||||
max_model_len: Optional[int] = None,
|
||||
quantization: Optional[str] = None,
|
||||
) -> None:
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
@ -66,11 +69,13 @@ class ModelConfig:
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
|
||||
self.hf_config = get_config(model, trust_remote_code, revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
self._verify_quantization()
|
||||
self.max_model_len = None
|
||||
if max_model_len is not None:
|
||||
derived_max_model_len = self.get_max_model_len()
|
||||
@ -100,6 +105,17 @@ class ModelConfig:
|
||||
"either 'auto' or 'slow'.")
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq"]
|
||||
if self.quantization is None:
|
||||
return
|
||||
quantization = self.quantization.lower()
|
||||
if quantization not in supported_quantization:
|
||||
raise ValueError(
|
||||
f"Unknown quantization: {self.quantization}. Must be one of "
|
||||
f"{supported_quantization}.")
|
||||
self.quantization = quantization
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
|
@ -29,6 +29,7 @@ class EngineArgs:
|
||||
max_num_seqs: int = 256
|
||||
disable_log_stats: bool = False
|
||||
revision: Optional[str] = None
|
||||
quantization: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.tokenizer is None:
|
||||
@ -88,7 +89,6 @@ class EngineArgs:
|
||||
'a numpy cache to speed up the loading. '
|
||||
'"dummy" will initialize the weights with random values, '
|
||||
'which is mainly for profiling.')
|
||||
# TODO(woosuk): Support FP32.
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
@ -150,6 +150,13 @@ class EngineArgs:
|
||||
parser.add_argument('--disable-log-stats',
|
||||
action='store_true',
|
||||
help='disable logging statistics')
|
||||
# Quantization settings.
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=str,
|
||||
choices=['awq', None],
|
||||
default=None,
|
||||
help='Method used to quantize the weights')
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
@ -163,12 +170,11 @@ class EngineArgs:
|
||||
def create_engine_configs(
|
||||
self,
|
||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||
# Initialize the configs.
|
||||
model_config = ModelConfig(self.model, self.tokenizer,
|
||||
self.tokenizer_mode, self.trust_remote_code,
|
||||
self.download_dir, self.load_format,
|
||||
self.dtype, self.seed, self.revision,
|
||||
self.max_model_len)
|
||||
self.max_model_len, self.quantization)
|
||||
cache_config = CacheConfig(self.block_size,
|
||||
self.gpu_memory_utilization,
|
||||
self.swap_space)
|
||||
|
@ -80,6 +80,7 @@ class LLMEngine:
|
||||
f"download_dir={model_config.download_dir!r}, "
|
||||
f"load_format={model_config.load_format}, "
|
||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
|
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
@ -0,0 +1,37 @@
|
||||
from vllm.model_executor.layers.quantized_linear.awq import (
|
||||
AWQColumnParallelLinear, AWQRowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
ColumnParallelLinear, RowParallelLinear)
|
||||
|
||||
_QUANTIZED_LINEAR_REGISTRY = {
|
||||
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
|
||||
}
|
||||
|
||||
|
||||
class ParallelLinear:
|
||||
|
||||
@classmethod
|
||||
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
|
||||
quant_config = kwargs.get("quant_config", None)
|
||||
if quant_config is None:
|
||||
return ColumnParallelLinear(*args, **kwargs)
|
||||
|
||||
name = quant_config.get_name()
|
||||
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||
raise ValueError(f"No quantized linear is found for {name}")
|
||||
|
||||
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
|
||||
return quant_linear_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def row(cls, *args, **kwargs) -> RowParallelLinear:
|
||||
quant_config = kwargs.get("quant_config", None)
|
||||
if quant_config is None:
|
||||
return RowParallelLinear(*args, **kwargs)
|
||||
|
||||
name = quant_config.get_name()
|
||||
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||
raise ValueError(f"No quantized linear is found for {name}")
|
||||
|
||||
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
|
||||
return quant_linear_cls(*args, **kwargs)
|
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
@ -0,0 +1,102 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import quantization_ops
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
|
||||
ColumnParallelLinear, RowParallelLinear)
|
||||
|
||||
|
||||
class AWQColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
assert self.input_size % self.quant_config.weight_bits == 0
|
||||
assert (self.output_size_per_partition %
|
||||
self.quant_config.pack_factor == 0)
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size,
|
||||
self.output_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.qzeros = Parameter(
|
||||
torch.empty(
|
||||
self.input_size // self.quant_config.group_size,
|
||||
self.output_size_per_partition //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.scales = Parameter(
|
||||
torch.empty(
|
||||
self.input_size // self.quant_config.group_size,
|
||||
self.output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||
self.qzeros, pack_factor)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
return out.reshape(out_shape)
|
||||
|
||||
|
||||
class AWQRowParallelLinear(RowParallelLinear):
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
assert (self.input_size_per_partition %
|
||||
self.quant_config.weight_bits == 0)
|
||||
assert self.output_size % self.quant_config.pack_factor == 0
|
||||
self.qweight = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition,
|
||||
self.output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.qzeros = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.group_size,
|
||||
self.output_size // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.scales = Parameter(
|
||||
torch.empty(
|
||||
self.input_size_per_partition // self.quant_config.group_size,
|
||||
self.output_size,
|
||||
device="cuda",
|
||||
dtype=dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
|
||||
self.qzeros, pack_factor)
|
||||
return out.reshape(out_shape)
|
@ -8,7 +8,8 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import * # pylint: disable=wildcard-import
|
||||
from vllm.model_executor.weight_utils import initialize_dummy_weights
|
||||
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||
initialize_dummy_weights)
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
@ -30,6 +31,11 @@ _MODEL_REGISTRY = {
|
||||
"RWForCausalLM": FalconForCausalLM,
|
||||
}
|
||||
|
||||
# FIXME(woosuk): Remove this once all models support quantization.
|
||||
_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
|
||||
LlamaForCausalLM,
|
||||
]
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
@ -52,10 +58,30 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
model_class = _get_model_architecture(model_config.hf_config)
|
||||
|
||||
# Get the quantization config.
|
||||
quant_config = None
|
||||
if model_config.quantization is not None:
|
||||
if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||
raise ValueError(
|
||||
f"Quantization is not supported for {model_class}.")
|
||||
quant_config = get_quant_config(model_config.quantization,
|
||||
model_config.model,
|
||||
model_config.download_dir)
|
||||
supported_dtypes = quant_config.get_supported_act_dtypes()
|
||||
if model_config.dtype not in supported_dtypes:
|
||||
raise ValueError(
|
||||
f"{model_config.dtype} is not supported for quantization "
|
||||
f"method {model_config.quantization}. Supported dtypes: "
|
||||
f"{supported_dtypes}")
|
||||
|
||||
with _set_default_torch_dtype(model_config.dtype):
|
||||
# Create a model instance.
|
||||
# The weights will be initialized as empty tensors.
|
||||
model = model_class(model_config.hf_config)
|
||||
if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
|
||||
model = model_class(model_config.hf_config, quant_config)
|
||||
else:
|
||||
model = model_class(model_config.hf_config)
|
||||
if model_config.load_format == "dummy":
|
||||
model = model.cuda()
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
|
@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.weight_utils import (
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.gate_up_proj = ParallelLinear.column(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -87,7 +92,8 @@ class LlamaAttention(nn.Module):
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
):
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@ -103,20 +109,22 @@ class LlamaAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
self.qkv_proj = ParallelLinear.column(
|
||||
hidden_size,
|
||||
(self.total_num_heads + 2 * self.total_num_kv_heads) *
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.o_proj = ParallelLinear.row(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
self.head_dim,
|
||||
@ -144,7 +152,11 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
@ -154,11 +166,13 @@ class LlamaDecoderLayer(nn.Module):
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.mlp = LlamaMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
@ -195,7 +209,11 @@ class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
class LlamaModel(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
@ -205,7 +223,8 @@ class LlamaModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size, config.hidden_size, perform_initialization=False)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
|
||||
LlamaDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
@ -237,16 +256,23 @@ class LlamaModel(nn.Module):
|
||||
|
||||
class LlamaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(
|
||||
self,
|
||||
config: LlamaConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = LlamaModel(config)
|
||||
self.quant_config = quant_config
|
||||
self.model = LlamaModel(config, quant_config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
# NOTE: The LM head is not quantized.
|
||||
self.lm_head = ParallelLinear.column(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=None)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
@ -263,16 +289,28 @@ class LlamaForCausalLM(nn.Module):
|
||||
input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = [
|
||||
"qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
|
||||
]
|
||||
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
|
||||
_column_parallel_layers = []
|
||||
_row_parallel_layers = ["o_proj", "down_proj"]
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
if self.quant_config is None:
|
||||
weight_suffixes = ["weight"]
|
||||
else:
|
||||
weight_suffixes = self.quant_config.get_tp_tensor_names()
|
||||
|
||||
column_parallel_weights: List[str] = []
|
||||
for layer in self._column_parallel_layers:
|
||||
for suffix in weight_suffixes:
|
||||
column_parallel_weights.append(f"{layer}.{suffix}")
|
||||
row_parallel_weights: List[str] = []
|
||||
for layer in self._row_parallel_layers:
|
||||
for suffix in weight_suffixes:
|
||||
row_parallel_weights.append(f"{layer}.{suffix}")
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
q_proj_shard_size = (self.config.hidden_size // tp_size)
|
||||
@ -293,11 +331,25 @@ class LlamaForCausalLM(nn.Module):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
is_packed = False
|
||||
is_transposed = False
|
||||
if self.quant_config is not None:
|
||||
is_packed = self.quant_config.is_packed(name)
|
||||
is_transposed = self.quant_config.is_transposed(name)
|
||||
if is_transposed:
|
||||
loaded_weight = loaded_weight.T
|
||||
|
||||
is_attention_weight = False
|
||||
for weight_name, shard_size, offset in attention_weight_specs:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "qkv_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if is_packed:
|
||||
shard_size //= self.quant_config.pack_factor
|
||||
offset //= self.quant_config.pack_factor
|
||||
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
@ -316,6 +368,9 @@ class LlamaForCausalLM(nn.Module):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = state_dict[name.replace(weight_name, "gate_up_proj")]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
shard_size = param.shape[0] // 2
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank:shard_size *
|
||||
@ -330,6 +385,8 @@ class LlamaForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
param = state_dict[name]
|
||||
if is_transposed:
|
||||
param = param.T
|
||||
|
||||
if "embed_tokens" in name or "lm_head" in name:
|
||||
load_padded_tensor_parallel_vocab(param, loaded_weight,
|
||||
@ -337,6 +394,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
continue
|
||||
|
||||
load_tensor_parallel_weights(param, loaded_weight, name,
|
||||
self._column_parallel_weights,
|
||||
self._row_parallel_weights,
|
||||
column_parallel_weights,
|
||||
row_parallel_weights,
|
||||
tensor_model_parallel_rank)
|
||||
|
@ -4,7 +4,7 @@
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from .mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
gather_from_tensor_model_parallel_region,
|
||||
reduce_from_tensor_model_parallel_region,
|
||||
scatter_to_tensor_model_parallel_region,
|
||||
)
|
||||
|
||||
from .random import get_cuda_rng_tracker
|
||||
from .utils import (
|
||||
divide,
|
||||
VocabUtility,
|
||||
@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
||||
maybe_copy(attribute)
|
||||
|
||||
|
||||
def _initialize_affine_weight_gpu(weight, init_method,
|
||||
partition_dim, stride=1):
|
||||
"""Initialize affine weight for model parallel on GPU."""
|
||||
|
||||
set_tensor_model_parallel_attributes(tensor=weight,
|
||||
is_parallel=True,
|
||||
dim=partition_dim,
|
||||
stride=stride)
|
||||
|
||||
with get_cuda_rng_tracker().fork():
|
||||
init_method(weight)
|
||||
|
||||
|
||||
def _initialize_affine_weight_cpu(weight, output_size, input_size,
|
||||
per_partition_size, partition_dim,
|
||||
init_method, stride=1,
|
||||
return_master_weight=False,
|
||||
*, params_dtype=None):
|
||||
"""Initialize affine weight for model parallel.
|
||||
|
||||
Build the master weight on all processes and scatter
|
||||
the relevant chunk."""
|
||||
|
||||
set_tensor_model_parallel_attributes(tensor=weight,
|
||||
is_parallel=True,
|
||||
dim=partition_dim,
|
||||
stride=stride)
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Initialize master weight
|
||||
master_weight = torch.empty(output_size, input_size,
|
||||
dtype=torch.float,
|
||||
requires_grad=False)
|
||||
init_method(master_weight)
|
||||
master_weight = master_weight.to(dtype=params_dtype)
|
||||
|
||||
# Split and copy
|
||||
per_partition_per_stride_size = divide(per_partition_size, stride)
|
||||
weight_list = torch.split(master_weight, per_partition_per_stride_size,
|
||||
dim=partition_dim)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
my_weight_list = weight_list[rank::world_size]
|
||||
|
||||
with torch.no_grad():
|
||||
torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
||||
if return_master_weight:
|
||||
return master_weight
|
||||
return None
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
@ -140,6 +85,9 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
use_cpu_initialization: bool=False,
|
||||
perform_initialization: bool=True):
|
||||
super(VocabParallelEmbedding, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
@ -162,24 +110,10 @@ class VocabParallelEmbedding(torch.nn.Module):
|
||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||
self.vocab_start_index
|
||||
|
||||
# Allocate weights and initialize.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_cpu(
|
||||
self.weight, self.num_embeddings, self.embedding_dim,
|
||||
self.num_embeddings_per_partition, 0, init_method,
|
||||
params_dtype=params_dtype)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=0, stride=1)
|
||||
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
# Build the mask.
|
||||
@ -239,8 +173,11 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=True,
|
||||
quant_config=None,
|
||||
):
|
||||
super(ColumnParallelLinear, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(self.output_size_per_partition,
|
||||
self.input_size,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight, self.output_size, self.input_size,
|
||||
self.output_size_per_partition, 0, init_method,
|
||||
stride=stride, return_master_weight=keep_master_weight_for_test)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.input_size,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=0, stride=stride)
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if bias:
|
||||
if use_cpu_initialization:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition, dtype=params_dtype))
|
||||
else:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.input_size,
|
||||
device=torch.cuda.current_device(), dtype=dtype))
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, bias)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
|
||||
input_parallel = input_
|
||||
# Matrix multiply.
|
||||
output_parallel = F.linear(input_parallel, self.weight, bias)
|
||||
output_parallel = self.apply_weights(input_parallel, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||
@ -361,8 +290,11 @@ class RowParallelLinear(torch.nn.Module):
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=True,
|
||||
reduce_results=True,
|
||||
quant_config=None,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(self.output_size,
|
||||
self.input_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight, self.output_size, self.input_size,
|
||||
self.input_size_per_partition, 1, init_method,
|
||||
stride=stride, return_master_weight=keep_master_weight_for_test,
|
||||
params_dtype=params_dtype)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size, self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=1, stride=stride)
|
||||
if bias:
|
||||
if use_cpu_initialization:
|
||||
self.bias = Parameter(torch.empty(self.output_size,
|
||||
dtype=params_dtype))
|
||||
else:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size, device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size, device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
self.weight_t = self.weight.t()
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size, self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(), dtype=dtype))
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
else:
|
||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = F.linear(input_parallel, self.weight)
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
if self.reduce_results and self.world_size > 1:
|
||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
|
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
@ -0,0 +1,20 @@
|
||||
from typing import Type
|
||||
|
||||
from vllm.model_executor.quantization_utils.awq import AWQConfig
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
_QUANTIZATION_REGISTRY = {
|
||||
"awq": AWQConfig,
|
||||
}
|
||||
|
||||
|
||||
def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in _QUANTIZATION_REGISTRY:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return _QUANTIZATION_REGISTRY[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quant_class",
|
||||
]
|
67
vllm/model_executor/quantization_utils/awq.py
Normal file
67
vllm/model_executor/quantization_utils/awq.py
Normal file
@ -0,0 +1,67 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
|
||||
class AWQConfig(QuantizationConfig):
|
||||
"""Config class for AWQ.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
) -> None:
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
|
||||
if self.weight_bits != 4:
|
||||
raise ValueError(
|
||||
"Currently, only 4-bit weight quantization is supported for "
|
||||
f"AWQ, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return [
|
||||
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
|
||||
"quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
return cls(weight_bits, group_size, zero_point)
|
||||
|
||||
@classmethod
|
||||
def get_packed_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros"]
|
||||
|
||||
@classmethod
|
||||
def get_transposed_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros", "scales"]
|
||||
|
||||
@classmethod
|
||||
def get_tp_tensor_names(cls) -> List[str]:
|
||||
return ["qweight", "qzeros", "scales"]
|
65
vllm/model_executor/quantization_utils/base.py
Normal file
65
vllm/model_executor/quantization_utils/base.py
Normal file
@ -0,0 +1,65 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class QuantizationConfig:
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
"""Name of the quantization method."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
"""List of supported activation dtypes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
"""List of filenames to search for in the model directory."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
|
||||
"""Create a config class from the model's quantization config."""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
|
||||
"""Get a value from the model's quantization config."""
|
||||
for key in keys:
|
||||
if key in config:
|
||||
return config[key]
|
||||
raise ValueError(f"Cannot find any of {keys} in the model's "
|
||||
"quantization config.")
|
||||
|
||||
@classmethod
|
||||
def get_packed_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def is_packed(cls, tensor_name: str) -> bool:
|
||||
"""Returns True if a tensor is packed.
|
||||
|
||||
A tensor is considered packed if each element in the tensor is a
|
||||
packed representation of multiple elements in the original tensor.
|
||||
For example, an INT32 element in the tensor may represent 8 INT4
|
||||
elements in the original tensor.
|
||||
"""
|
||||
return any(tag in tensor_name for tag in cls.get_packed_tensor_names())
|
||||
|
||||
@classmethod
|
||||
def get_transposed_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def is_transposed(cls, tensor_name: str) -> bool:
|
||||
"""Returns True if a tensor is transposed relative to nn.Linear.weight.
|
||||
"""
|
||||
return any(tag in tensor_name
|
||||
for tag in cls.get_transposed_tensor_names())
|
||||
|
||||
@classmethod
|
||||
def get_tp_tensor_names(cls) -> List[str]:
|
||||
raise NotImplementedError
|
@ -4,7 +4,7 @@ import glob
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Iterator, List, Optional, Tuple, Any
|
||||
from typing import Any, Iterator, List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
@ -13,6 +13,8 @@ import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.quantization_utils import get_quant_class
|
||||
from vllm.model_executor.quantization_utils.base import QuantizationConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -44,7 +46,7 @@ def _shared_pointers(tensors):
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
):
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu")
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
@ -78,16 +80,55 @@ def convert_bin_to_safetensor_file(
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(
|
||||
quantization: str,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
) -> QuantizationConfig:
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
hf_folder = snapshot_download(model_name_or_path,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=Disabledtqdm)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_cls = get_quant_class(quantization)
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(
|
||||
f.endswith(x) for x in quant_cls.get_config_filenames())
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(f"Cannot find the config file for {quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(f"Found multiple config files for {quantization}: "
|
||||
f"{quant_config_files}")
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def prepare_hf_model_weights(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
use_safetensors: bool = False,
|
||||
fall_back_to_pt: bool = True,
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
) -> Tuple[str, List[str], bool]:
|
||||
# Download model weights from huggingface.
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
|
||||
if use_safetensors:
|
||||
allow_patterns = ["*.safetensors"]
|
||||
else:
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
allow_patterns = ["*.bin", "*.pt"]
|
||||
if not is_local:
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
@ -99,7 +140,9 @@ def prepare_hf_model_weights(
|
||||
revision=revision)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
|
||||
hf_weights_files: List[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if not use_safetensors:
|
||||
hf_weights_files = [
|
||||
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
||||
|
Loading…
x
Reference in New Issue
Block a user