Implement AWQ quantization support for LLaMA (#1032)
Co-authored-by: Robert Irvine <robert@seamlessml.com> Co-authored-by: root <rirv938@gmail.com> Co-authored-by: Casper <casperbh.96@gmail.com> Co-authored-by: julian-q <julianhquevedo@gmail.com>
This commit is contained in:
parent
b9fe4616f9
commit
e3e79e9e8a
4
.gitignore
vendored
4
.gitignore
vendored
@ -173,3 +173,7 @@ cython_debug/
|
|||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
_build/
|
_build/
|
||||||
|
|
||||||
|
# vim swap files
|
||||||
|
*.swo
|
||||||
|
*.swp
|
||||||
|
@ -18,6 +18,7 @@ def main(args: argparse.Namespace):
|
|||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=args.model,
|
model=args.model,
|
||||||
tokenizer=args.tokenizer,
|
tokenizer=args.tokenizer,
|
||||||
|
quantization=args.quantization,
|
||||||
tensor_parallel_size=args.tensor_parallel_size,
|
tensor_parallel_size=args.tensor_parallel_size,
|
||||||
max_num_seqs=args.batch_size,
|
max_num_seqs=args.batch_size,
|
||||||
max_num_batched_tokens=args.batch_size * args.input_len,
|
max_num_batched_tokens=args.batch_size * args.input_len,
|
||||||
@ -63,19 +64,28 @@ def main(args: argparse.Namespace):
|
|||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description='Benchmark the latency of processing a single batch of '
|
description='Benchmark the latency of processing a single batch of '
|
||||||
'requests till completion.')
|
'requests till completion.')
|
||||||
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
parser.add_argument('--model', type=str, default='facebook/opt-125m')
|
||||||
parser.add_argument('--tokenizer', type=str, default=None)
|
parser.add_argument('--tokenizer', type=str, default=None)
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None)
|
||||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||||
parser.add_argument('--input-len', type=int, default=32)
|
parser.add_argument('--input-len', type=int, default=32)
|
||||||
parser.add_argument('--output-len', type=int, default=128)
|
parser.add_argument('--output-len', type=int, default=128)
|
||||||
parser.add_argument('--batch-size', type=int, default=8)
|
parser.add_argument('--batch-size', type=int, default=8)
|
||||||
parser.add_argument('--n', type=int, default=1,
|
parser.add_argument('--n',
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
help='Number of generated sequences per prompt.')
|
help='Number of generated sequences per prompt.')
|
||||||
parser.add_argument('--use-beam-search', action='store_true')
|
parser.add_argument('--use-beam-search', action='store_true')
|
||||||
parser.add_argument('--num-iters', type=int, default=3,
|
parser.add_argument('--num-iters',
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
help='Number of iterations to run.')
|
help='Number of iterations to run.')
|
||||||
parser.add_argument('--trust-remote-code', action='store_true',
|
parser.add_argument('--trust-remote-code',
|
||||||
|
action='store_true',
|
||||||
help='trust remote code from huggingface')
|
help='trust remote code from huggingface')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
@ -3,7 +3,7 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase
|
||||||
@ -22,15 +22,10 @@ def sample_requests(
|
|||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
data for data in dataset
|
|
||||||
if len(data["conversations"]) >= 2
|
|
||||||
]
|
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [(data["conversations"][0]["value"],
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
data["conversations"][1]["value"]) for data in dataset]
|
||||||
for data in dataset
|
|
||||||
]
|
|
||||||
|
|
||||||
# Tokenize the prompts and completions.
|
# Tokenize the prompts and completions.
|
||||||
prompts = [prompt for prompt, _ in dataset]
|
prompts = [prompt for prompt, _ in dataset]
|
||||||
@ -63,6 +58,7 @@ def run_vllm(
|
|||||||
requests: List[Tuple[str, int, int]],
|
requests: List[Tuple[str, int, int]],
|
||||||
model: str,
|
model: str,
|
||||||
tokenizer: str,
|
tokenizer: str,
|
||||||
|
quantization: Optional[str],
|
||||||
tensor_parallel_size: int,
|
tensor_parallel_size: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
n: int,
|
n: int,
|
||||||
@ -72,6 +68,7 @@ def run_vllm(
|
|||||||
llm = LLM(
|
llm = LLM(
|
||||||
model=model,
|
model=model,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
|
quantization=quantization,
|
||||||
tensor_parallel_size=tensor_parallel_size,
|
tensor_parallel_size=tensor_parallel_size,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
@ -111,8 +108,8 @@ def run_hf(
|
|||||||
trust_remote_code: bool,
|
trust_remote_code: bool,
|
||||||
) -> float:
|
) -> float:
|
||||||
assert not use_beam_search
|
assert not use_beam_search
|
||||||
llm = AutoModelForCausalLM.from_pretrained(model,
|
llm = AutoModelForCausalLM.from_pretrained(
|
||||||
torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
|
||||||
if llm.config.model_type == "llama":
|
if llm.config.model_type == "llama":
|
||||||
# To enable padding in the HF backend.
|
# To enable padding in the HF backend.
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
@ -132,13 +129,14 @@ def run_hf(
|
|||||||
if len(batch) < max_batch_size and i != len(requests) - 1:
|
if len(batch) < max_batch_size and i != len(requests) - 1:
|
||||||
# Check if we can add more requests to the batch.
|
# Check if we can add more requests to the batch.
|
||||||
_, next_prompt_len, next_output_len = requests[i + 1]
|
_, next_prompt_len, next_output_len = requests[i + 1]
|
||||||
if (max(max_prompt_len, next_prompt_len) + max(
|
if (max(max_prompt_len, next_prompt_len) +
|
||||||
max_output_len, next_output_len)) <= 2048:
|
max(max_output_len, next_output_len)) <= 2048:
|
||||||
# We can add more requests to the batch.
|
# We can add more requests to the batch.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Generate the sequences.
|
# Generate the sequences.
|
||||||
input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids
|
input_ids = tokenizer(batch, return_tensors="pt",
|
||||||
|
padding=True).input_ids
|
||||||
llm_outputs = llm.generate(
|
llm_outputs = llm.generate(
|
||||||
input_ids=input_ids.cuda(),
|
input_ids=input_ids.cuda(),
|
||||||
do_sample=not use_beam_search,
|
do_sample=not use_beam_search,
|
||||||
@ -165,44 +163,58 @@ def main(args: argparse.Namespace):
|
|||||||
random.seed(args.seed)
|
random.seed(args.seed)
|
||||||
|
|
||||||
# Sample the requests.
|
# Sample the requests.
|
||||||
tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
|
tokenizer = get_tokenizer(args.tokenizer,
|
||||||
|
trust_remote_code=args.trust_remote_code)
|
||||||
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
|
|
||||||
if args.backend == "vllm":
|
if args.backend == "vllm":
|
||||||
elapsed_time = run_vllm(
|
elapsed_time = run_vllm(requests, args.model, args.tokenizer,
|
||||||
requests, args.model, args.tokenizer, args.tensor_parallel_size,
|
args.quantization, args.tensor_parallel_size,
|
||||||
args.seed, args.n, args.use_beam_search, args.trust_remote_code)
|
args.seed, args.n, args.use_beam_search,
|
||||||
|
args.trust_remote_code)
|
||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
assert args.tensor_parallel_size == 1
|
assert args.tensor_parallel_size == 1
|
||||||
elapsed_time = run_hf(
|
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
|
||||||
requests, args.model, tokenizer, args.n, args.use_beam_search,
|
args.use_beam_search, args.hf_max_batch_size,
|
||||||
args.hf_max_batch_size, args.trust_remote_code)
|
args.trust_remote_code)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown backend: {args.backend}")
|
raise ValueError(f"Unknown backend: {args.backend}")
|
||||||
total_num_tokens = sum(
|
total_num_tokens = sum(prompt_len + output_len
|
||||||
prompt_len + output_len
|
for _, prompt_len, output_len in requests)
|
||||||
for _, prompt_len, output_len in requests
|
|
||||||
)
|
|
||||||
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
|
||||||
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
|
||||||
parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
|
parser.add_argument("--backend",
|
||||||
|
type=str,
|
||||||
|
choices=["vllm", "hf"],
|
||||||
default="vllm")
|
default="vllm")
|
||||||
parser.add_argument("--dataset", type=str, required=True,
|
parser.add_argument("--dataset",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
help="Path to the dataset.")
|
help="Path to the dataset.")
|
||||||
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
parser.add_argument("--model", type=str, default="facebook/opt-125m")
|
||||||
parser.add_argument("--tokenizer", type=str, default=None)
|
parser.add_argument("--tokenizer", type=str, default=None)
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None)
|
||||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||||
parser.add_argument("--n", type=int, default=1,
|
parser.add_argument("--n",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
help="Number of generated sequences per prompt.")
|
help="Number of generated sequences per prompt.")
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
parser.add_argument("--num-prompts",
|
||||||
|
type=int,
|
||||||
|
default=1000,
|
||||||
help="Number of prompts to process.")
|
help="Number of prompts to process.")
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument("--hf-max-batch-size", type=int, default=None,
|
parser.add_argument("--hf-max-batch-size",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
help="Maximum batch size for HF backend.")
|
help="Maximum batch size for HF backend.")
|
||||||
parser.add_argument('--trust-remote-code',
|
parser.add_argument('--trust-remote-code',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
@ -215,6 +227,8 @@ if __name__ == "__main__":
|
|||||||
elif args.backend == "hf":
|
elif args.backend == "hf":
|
||||||
if args.hf_max_batch_size is None:
|
if args.hf_max_batch_size is None:
|
||||||
raise ValueError("HF max batch size is required for HF backend.")
|
raise ValueError("HF max batch size is required for HF backend.")
|
||||||
|
if args.quantization is not None:
|
||||||
|
raise ValueError("Quantization is only for vLLM backend.")
|
||||||
if args.tokenizer is None:
|
if args.tokenizer is None:
|
||||||
args.tokenizer = args.model
|
args.tokenizer = args.model
|
||||||
|
|
||||||
|
15
csrc/quantization.cpp
Normal file
15
csrc/quantization.cpp
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
#include <torch/extension.h>
|
||||||
|
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters);
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||||
|
m.def(
|
||||||
|
"awq_gemm",
|
||||||
|
&awq_gemm,
|
||||||
|
"Quantized GEMM for AWQ");
|
||||||
|
}
|
79
csrc/quantization/awq/dequantize.cuh
Normal file
79
csrc/quantization/awq/dequantize.cuh
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
|
||||||
|
__device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
|
||||||
|
{
|
||||||
|
uint4 result;
|
||||||
|
|
||||||
|
uint32_t* h = reinterpret_cast<uint32_t*>(&result);
|
||||||
|
uint32_t const i4s = reinterpret_cast<uint32_t const&>(source);
|
||||||
|
|
||||||
|
// First, we extract the i4s and construct an intermediate fp16 number.
|
||||||
|
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
|
||||||
|
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
|
||||||
|
static constexpr uint32_t TOP_MASK = 0x00f000f0;
|
||||||
|
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||||
|
|
||||||
|
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
|
||||||
|
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
|
||||||
|
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
|
||||||
|
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
|
||||||
|
|
||||||
|
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
|
||||||
|
// immediately before required.
|
||||||
|
const uint32_t top_i4s = i4s >> 8;
|
||||||
|
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[0])
|
||||||
|
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[1])
|
||||||
|
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[2])
|
||||||
|
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
|
||||||
|
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
|
||||||
|
: "=r"(h[3])
|
||||||
|
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
|
||||||
|
|
||||||
|
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
|
||||||
|
// half2 ctor. In this case, I chose performance reliability over code readability.
|
||||||
|
|
||||||
|
// This is the half2 {1032, 1032} represented as an integer.
|
||||||
|
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
|
||||||
|
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
|
||||||
|
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
|
||||||
|
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
|
||||||
|
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
|
||||||
|
// This is the half2 {-72, -72} represented as an integer.
|
||||||
|
// static constexpr uint32_t NEG_72 = 0xd480d480;
|
||||||
|
// Haotian: Let's use {-64, -64}.
|
||||||
|
static constexpr uint32_t NEG_64 = 0xd400d400;
|
||||||
|
|
||||||
|
// Finally, we construct the output numbers.
|
||||||
|
// Convert elt_01
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_23
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
// Convert elt_45
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
|
||||||
|
// Convert elt_67
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
477
csrc/quantization/awq/gemm_kernels.cu
Normal file
477
csrc/quantization/awq/gemm_kernels.cu
Normal file
@ -0,0 +1,477 @@
|
|||||||
|
/*
|
||||||
|
Adapted from https://github.com/mit-han-lab/llm-awq
|
||||||
|
@article{lin2023awq,
|
||||||
|
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
|
||||||
|
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
|
||||||
|
journal={arXiv},
|
||||||
|
year={2023}
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
|
||||||
|
#include "dequantize.cuh"
|
||||||
|
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
// Pack two half values.
|
||||||
|
static inline __device__ __host__ unsigned
|
||||||
|
__pack_half2(const half x, const half y) {
|
||||||
|
unsigned v0 = *((unsigned short *)&x);
|
||||||
|
unsigned v1 = *((unsigned short *)&y);
|
||||||
|
return (v1 << 16) | v0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
|
{
|
||||||
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
|
float C_warp[32];
|
||||||
|
__shared__ half A_shared[16 * (32 + 8)];
|
||||||
|
__shared__ half B_shared[32 * (128 + 8)];
|
||||||
|
|
||||||
|
__shared__ half scaling_factors_shared[128];
|
||||||
|
__shared__ half zeros_shared[128];
|
||||||
|
|
||||||
|
int j_factors1 = ((OC + 128 - 1) / 128);
|
||||||
|
int blockIdx_x = 0;
|
||||||
|
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
|
||||||
|
half A_shared_warp[8];
|
||||||
|
half B_shared_warp[32];
|
||||||
|
for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||||
|
static constexpr int row_stride = 2 * 32 * 8 / 128;
|
||||||
|
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
|
|
||||||
|
half* A_ptr = A
|
||||||
|
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
|
int* B_ptr = B
|
||||||
|
+ ((int)threadIdx.y) * (OC / 8) * 2
|
||||||
|
+ (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 1;
|
||||||
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
|
half* A_shared_ptr = A_shared
|
||||||
|
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||||
|
|
||||||
|
half* B_shared_ptr = B_shared
|
||||||
|
+ ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||||
|
|
||||||
|
int* zeros_ptr = zeros
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128 / 8)
|
||||||
|
+ ((int)threadIdx.x) % (128 / 8);
|
||||||
|
|
||||||
|
half* scaling_factors_ptr = scaling_factors
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (128)
|
||||||
|
+ (((int)threadIdx.x) % (128 / 8)) * 8;
|
||||||
|
|
||||||
|
half* C_ptr = C
|
||||||
|
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * 128
|
||||||
|
+ ((int)threadIdx.y) * 64
|
||||||
|
+ (((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
|
// preload s.f. and zeros
|
||||||
|
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||||
|
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||||
|
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||||
|
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||||
|
__syncthreads();
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
if (ld_A_flag)
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||||
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||||
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
|
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||||
|
/*
|
||||||
|
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||||
|
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
|
// B: 32 x 136 (128+8) float16
|
||||||
|
// each warp: 32 x 4
|
||||||
|
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||||
|
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||||
|
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||||
|
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||||
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
|
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
|
||||||
|
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
// - zero and * scale
|
||||||
|
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
/*
|
||||||
|
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||||
|
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// write back
|
||||||
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) {
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
|
||||||
|
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Shang: Hoist loop invariance.
|
||||||
|
for (int ax1_0_1 = 0; ax1_0_1 < 4; ++ax1_0_1) {
|
||||||
|
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||||
|
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||||
|
if (row_offset < M)
|
||||||
|
{
|
||||||
|
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
|
||||||
|
{
|
||||||
|
static constexpr uint32_t ZERO = 0x0;
|
||||||
|
float C_warp[32];
|
||||||
|
__shared__ half A_shared[16 * (32 + 8)];
|
||||||
|
__shared__ half B_shared[32 * (64 + 8)];
|
||||||
|
|
||||||
|
__shared__ half scaling_factors_shared[64];
|
||||||
|
__shared__ half zeros_shared[64];
|
||||||
|
|
||||||
|
int j_factors1 = ((OC + 64 - 1) / 64);
|
||||||
|
|
||||||
|
int blockIdx_x = 0;
|
||||||
|
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||||
|
|
||||||
|
half A_shared_warp[8];
|
||||||
|
half B_shared_warp[16];
|
||||||
|
for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
|
||||||
|
for (int i = 0; i < 8; ++i) {
|
||||||
|
C_warp[(j_0_4_init * 8) + i] = 0.0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||||
|
static constexpr int row_stride = 2 * 32 * 8 / 64;
|
||||||
|
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id
|
||||||
|
// bool wb_C_flag = (threadIdx.x / 4) < M;
|
||||||
|
|
||||||
|
half* A_ptr = A
|
||||||
|
+ (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8)) * 8;
|
||||||
|
|
||||||
|
int* B_ptr = B
|
||||||
|
+ ((int)threadIdx.y) * (OC / 8) * 4
|
||||||
|
+ (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 1;
|
||||||
|
// Why * 1 in the above line?
|
||||||
|
|
||||||
|
half* A_shared_ptr = A_shared
|
||||||
|
+ ((int)threadIdx.y) * row_stride_warp * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (32 / 8) ) * 8;
|
||||||
|
|
||||||
|
half* B_shared_ptr = B_shared
|
||||||
|
+ ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
|
||||||
|
+ (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||||
|
|
||||||
|
int* zeros_ptr = zeros
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64 / 8)
|
||||||
|
+ ((int)threadIdx.x) % (64 / 8);
|
||||||
|
|
||||||
|
half* scaling_factors_ptr = scaling_factors
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * (64)
|
||||||
|
+ (((int)threadIdx.x) % (64 / 8)) * 8;
|
||||||
|
|
||||||
|
half* C_ptr = C
|
||||||
|
+ blockIdx_z * M * OC // blockIdz.x -> split_k dim
|
||||||
|
+ (((int)blockIdx_y) % j_factors1) * 64
|
||||||
|
+ ((int)threadIdx.y) * 32
|
||||||
|
+ (((int)threadIdx.x) % 4) * 2;
|
||||||
|
|
||||||
|
// preload s.f. and zeros
|
||||||
|
int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
|
||||||
|
if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
|
||||||
|
for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
|
||||||
|
int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
|
||||||
|
__syncthreads();
|
||||||
|
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||||
|
if (ld_A_flag)
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
*(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
|
||||||
|
uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
|
||||||
|
uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
|
||||||
|
uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
|
||||||
|
/*
|
||||||
|
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
|
||||||
|
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
|
||||||
|
int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
|
||||||
|
|
||||||
|
for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
|
||||||
|
|
||||||
|
// B: 32 x 136 (128+8) float16
|
||||||
|
// each warp: 32 x 4
|
||||||
|
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
|
||||||
|
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
|
||||||
|
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
|
||||||
|
uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||||
|
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||||
|
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
|
||||||
|
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
|
||||||
|
// - zero and * scale
|
||||||
|
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
|
||||||
|
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
|
||||||
|
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
|
||||||
|
/*
|
||||||
|
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
|
||||||
|
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
// write back
|
||||||
|
*(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1)
|
||||||
|
{
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0)
|
||||||
|
{
|
||||||
|
{
|
||||||
|
unsigned int addr;
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
|
||||||
|
: "=r"(addr)
|
||||||
|
: "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
|
||||||
|
);
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||||
|
"{%0, %1, %2, %3}, [%4];\n"
|
||||||
|
: "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
|
||||||
|
: "r"(addr)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4)
|
||||||
|
{
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
__asm__ __volatile__(
|
||||||
|
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
|
||||||
|
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
|
||||||
|
: "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
|
||||||
|
: "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Shang: Hoist loop invariance.
|
||||||
|
for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
|
||||||
|
for (int local_id = 0; local_id < 8; ++local_id) {
|
||||||
|
int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
|
||||||
|
if (row_offset < M)
|
||||||
|
{
|
||||||
|
*(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// in_feats: M, IC [float16]
|
||||||
|
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
|
||||||
|
// scaling_factors: IC // G, OC [float16]
|
||||||
|
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
|
||||||
|
// assume that batch_size < 16 for now
|
||||||
|
|
||||||
|
torch::Tensor awq_gemm(
|
||||||
|
torch::Tensor _in_feats,
|
||||||
|
torch::Tensor _kernel,
|
||||||
|
torch::Tensor _scaling_factors,
|
||||||
|
torch::Tensor _zeros,
|
||||||
|
int split_k_iters)
|
||||||
|
{
|
||||||
|
int num_in_feats = _in_feats.size(0);
|
||||||
|
int num_in_channels = _in_feats.size(1);
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats));
|
||||||
|
|
||||||
|
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
|
||||||
|
at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options);
|
||||||
|
int num_out_feats = _out_feats.size(-2);
|
||||||
|
int num_out_channels = _out_feats.size(-1);
|
||||||
|
|
||||||
|
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
|
||||||
|
auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
|
||||||
|
auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
|
||||||
|
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
|
||||||
|
auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
|
||||||
|
int group_size = num_in_channels / _scaling_factors.size(0);
|
||||||
|
|
||||||
|
if (num_out_channels % 64 != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of cta_N = 64");
|
||||||
|
if (num_out_channels % 8 != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of pack_num = 8");
|
||||||
|
if (group_size % 32 != 0)
|
||||||
|
throw std::invalid_argument("Group size should be a multiple of 32");
|
||||||
|
if (num_out_channels % group_size != 0)
|
||||||
|
throw std::invalid_argument("OC is not multiple of Group size");
|
||||||
|
|
||||||
|
if (num_out_channels % 128 == 0)
|
||||||
|
{
|
||||||
|
int j_factors1 = num_out_channels / 128 / 1;
|
||||||
|
dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||||
|
// threadIdx.x: 32
|
||||||
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
|
dim3 threads_per_block(32, 2);
|
||||||
|
gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
|
||||||
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
|
}
|
||||||
|
else if (num_out_channels % 64 == 0)
|
||||||
|
{
|
||||||
|
int j_factors1 = num_out_channels / 64 / 1;
|
||||||
|
dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters);
|
||||||
|
|
||||||
|
// threadIdx.x: 32
|
||||||
|
// threadIdx.y: i_factors[2] * j_factors[2]
|
||||||
|
dim3 threads_per_block(32, 2);
|
||||||
|
gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
|
||||||
|
group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
|
||||||
|
}
|
||||||
|
return _out_feats.sum(0);
|
||||||
|
}
|
14
setup.py
14
setup.py
@ -146,6 +146,20 @@ activation_extension = CUDAExtension(
|
|||||||
)
|
)
|
||||||
ext_modules.append(activation_extension)
|
ext_modules.append(activation_extension)
|
||||||
|
|
||||||
|
# Quantization kernels.
|
||||||
|
quantization_extension = CUDAExtension(
|
||||||
|
name="vllm.quantization_ops",
|
||||||
|
sources=[
|
||||||
|
"csrc/quantization.cpp",
|
||||||
|
"csrc/quantization/awq/gemm_kernels.cu",
|
||||||
|
],
|
||||||
|
extra_compile_args={
|
||||||
|
"cxx": CXX_FLAGS,
|
||||||
|
"nvcc": NVCC_FLAGS,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
ext_modules.append(quantization_extension)
|
||||||
|
|
||||||
|
|
||||||
def get_path(*filepath) -> str:
|
def get_path(*filepath) -> str:
|
||||||
return os.path.join(ROOT_DIR, *filepath)
|
return os.path.join(ROOT_DIR, *filepath)
|
||||||
|
@ -43,6 +43,8 @@ class ModelConfig:
|
|||||||
version.
|
version.
|
||||||
max_model_len: Maximum length of a sequence (including prompt and
|
max_model_len: Maximum length of a sequence (including prompt and
|
||||||
output). If None, will be derived from the model.
|
output). If None, will be derived from the model.
|
||||||
|
quantization: Quantization method that was used to quantize the model
|
||||||
|
weights. If None, we assume the model weights are not quantized.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -57,6 +59,7 @@ class ModelConfig:
|
|||||||
seed: int,
|
seed: int,
|
||||||
revision: Optional[str],
|
revision: Optional[str],
|
||||||
max_model_len: Optional[int] = None,
|
max_model_len: Optional[int] = None,
|
||||||
|
quantization: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
@ -66,11 +69,13 @@ class ModelConfig:
|
|||||||
self.load_format = load_format
|
self.load_format = load_format
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
self.revision = revision
|
self.revision = revision
|
||||||
|
self.quantization = quantization
|
||||||
|
|
||||||
self.hf_config = get_config(model, trust_remote_code, revision)
|
self.hf_config = get_config(model, trust_remote_code, revision)
|
||||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||||
self._verify_load_format()
|
self._verify_load_format()
|
||||||
self._verify_tokenizer_mode()
|
self._verify_tokenizer_mode()
|
||||||
|
self._verify_quantization()
|
||||||
self.max_model_len = None
|
self.max_model_len = None
|
||||||
if max_model_len is not None:
|
if max_model_len is not None:
|
||||||
derived_max_model_len = self.get_max_model_len()
|
derived_max_model_len = self.get_max_model_len()
|
||||||
@ -100,6 +105,17 @@ class ModelConfig:
|
|||||||
"either 'auto' or 'slow'.")
|
"either 'auto' or 'slow'.")
|
||||||
self.tokenizer_mode = tokenizer_mode
|
self.tokenizer_mode = tokenizer_mode
|
||||||
|
|
||||||
|
def _verify_quantization(self) -> None:
|
||||||
|
supported_quantization = ["awq"]
|
||||||
|
if self.quantization is None:
|
||||||
|
return
|
||||||
|
quantization = self.quantization.lower()
|
||||||
|
if quantization not in supported_quantization:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown quantization: {self.quantization}. Must be one of "
|
||||||
|
f"{supported_quantization}.")
|
||||||
|
self.quantization = quantization
|
||||||
|
|
||||||
def verify_with_parallel_config(
|
def verify_with_parallel_config(
|
||||||
self,
|
self,
|
||||||
parallel_config: "ParallelConfig",
|
parallel_config: "ParallelConfig",
|
||||||
|
@ -29,6 +29,7 @@ class EngineArgs:
|
|||||||
max_num_seqs: int = 256
|
max_num_seqs: int = 256
|
||||||
disable_log_stats: bool = False
|
disable_log_stats: bool = False
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
|
quantization: Optional[str] = None
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.tokenizer is None:
|
if self.tokenizer is None:
|
||||||
@ -88,7 +89,6 @@ class EngineArgs:
|
|||||||
'a numpy cache to speed up the loading. '
|
'a numpy cache to speed up the loading. '
|
||||||
'"dummy" will initialize the weights with random values, '
|
'"dummy" will initialize the weights with random values, '
|
||||||
'which is mainly for profiling.')
|
'which is mainly for profiling.')
|
||||||
# TODO(woosuk): Support FP32.
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--dtype',
|
'--dtype',
|
||||||
type=str,
|
type=str,
|
||||||
@ -150,6 +150,13 @@ class EngineArgs:
|
|||||||
parser.add_argument('--disable-log-stats',
|
parser.add_argument('--disable-log-stats',
|
||||||
action='store_true',
|
action='store_true',
|
||||||
help='disable logging statistics')
|
help='disable logging statistics')
|
||||||
|
# Quantization settings.
|
||||||
|
parser.add_argument('--quantization',
|
||||||
|
'-q',
|
||||||
|
type=str,
|
||||||
|
choices=['awq', None],
|
||||||
|
default=None,
|
||||||
|
help='Method used to quantize the weights')
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -163,12 +170,11 @@ class EngineArgs:
|
|||||||
def create_engine_configs(
|
def create_engine_configs(
|
||||||
self,
|
self,
|
||||||
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
|
||||||
# Initialize the configs.
|
|
||||||
model_config = ModelConfig(self.model, self.tokenizer,
|
model_config = ModelConfig(self.model, self.tokenizer,
|
||||||
self.tokenizer_mode, self.trust_remote_code,
|
self.tokenizer_mode, self.trust_remote_code,
|
||||||
self.download_dir, self.load_format,
|
self.download_dir, self.load_format,
|
||||||
self.dtype, self.seed, self.revision,
|
self.dtype, self.seed, self.revision,
|
||||||
self.max_model_len)
|
self.max_model_len, self.quantization)
|
||||||
cache_config = CacheConfig(self.block_size,
|
cache_config = CacheConfig(self.block_size,
|
||||||
self.gpu_memory_utilization,
|
self.gpu_memory_utilization,
|
||||||
self.swap_space)
|
self.swap_space)
|
||||||
|
@ -80,6 +80,7 @@ class LLMEngine:
|
|||||||
f"download_dir={model_config.download_dir!r}, "
|
f"download_dir={model_config.download_dir!r}, "
|
||||||
f"load_format={model_config.load_format}, "
|
f"load_format={model_config.load_format}, "
|
||||||
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
f"tensor_parallel_size={parallel_config.tensor_parallel_size}, "
|
||||||
|
f"quantization={model_config.quantization}, "
|
||||||
f"seed={model_config.seed})")
|
f"seed={model_config.seed})")
|
||||||
# TODO(woosuk): Print more configs in debug mode.
|
# TODO(woosuk): Print more configs in debug mode.
|
||||||
|
|
||||||
|
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
37
vllm/model_executor/layers/quantized_linear/__init__.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from vllm.model_executor.layers.quantized_linear.awq import (
|
||||||
|
AWQColumnParallelLinear, AWQRowParallelLinear)
|
||||||
|
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||||
|
ColumnParallelLinear, RowParallelLinear)
|
||||||
|
|
||||||
|
_QUANTIZED_LINEAR_REGISTRY = {
|
||||||
|
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelLinear:
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def column(cls, *args, **kwargs) -> ColumnParallelLinear:
|
||||||
|
quant_config = kwargs.get("quant_config", None)
|
||||||
|
if quant_config is None:
|
||||||
|
return ColumnParallelLinear(*args, **kwargs)
|
||||||
|
|
||||||
|
name = quant_config.get_name()
|
||||||
|
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||||
|
raise ValueError(f"No quantized linear is found for {name}")
|
||||||
|
|
||||||
|
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0]
|
||||||
|
return quant_linear_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def row(cls, *args, **kwargs) -> RowParallelLinear:
|
||||||
|
quant_config = kwargs.get("quant_config", None)
|
||||||
|
if quant_config is None:
|
||||||
|
return RowParallelLinear(*args, **kwargs)
|
||||||
|
|
||||||
|
name = quant_config.get_name()
|
||||||
|
if name not in _QUANTIZED_LINEAR_REGISTRY:
|
||||||
|
raise ValueError(f"No quantized linear is found for {name}")
|
||||||
|
|
||||||
|
quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1]
|
||||||
|
return quant_linear_cls(*args, **kwargs)
|
102
102
vllm/model_executor/layers/quantized_linear/awq.py
Normal file
@ -0,0 +1,102 @@
from typing import Optional

import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
    ColumnParallelLinear, RowParallelLinear)


class AWQColumnParallelLinear(ColumnParallelLinear):

    def create_weights(self, dtype: torch.dtype) -> None:
        assert self.input_size % self.quant_config.weight_bits == 0
        assert (self.output_size_per_partition %
                self.quant_config.pack_factor == 0)
        self.qweight = Parameter(
            torch.empty(
                self.input_size,
                self.output_size_per_partition //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.qzeros = Parameter(
            torch.empty(
                self.input_size // self.quant_config.group_size,
                self.output_size_per_partition //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.scales = Parameter(
            torch.empty(
                self.input_size // self.quant_config.group_size,
                self.output_size_per_partition,
                device="cuda",
                dtype=dtype,
            ),
            requires_grad=False,
        )

    def apply_weights(
        self,
        x: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> torch.Tensor:
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                        self.qzeros, pack_factor)
        if bias is not None:
            out = out + bias
        return out.reshape(out_shape)


class AWQRowParallelLinear(RowParallelLinear):

    def create_weights(self, dtype: torch.dtype) -> None:
        assert (self.input_size_per_partition %
                self.quant_config.weight_bits == 0)
        assert self.output_size % self.quant_config.pack_factor == 0
        self.qweight = Parameter(
            torch.empty(
                self.input_size_per_partition,
                self.output_size // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.qzeros = Parameter(
            torch.empty(
                self.input_size_per_partition // self.quant_config.group_size,
                self.output_size // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        self.scales = Parameter(
            torch.empty(
                self.input_size_per_partition // self.quant_config.group_size,
                self.output_size,
                device="cuda",
                dtype=dtype,
            ),
            requires_grad=False,
        )

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        pack_factor = self.quant_config.pack_factor
        out_shape = (x.shape[-2], self.qweight.shape[-1] * pack_factor)
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales,
                                        self.qzeros, pack_factor)
        return out.reshape(out_shape)
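A quick sanity check of the shapes these layers allocate. This is an illustrative sketch only: the layer sizes and group size below are made-up example values, and it simply replays the pack_factor arithmetic from the file above (eight 4-bit weights packed into one INT32 element).

import torch

# Hypothetical example sizes; a real layer gets these from the model config.
input_size = 4096                 # K
output_size_partition = 11008     # N for this tensor-parallel partition
weight_bits = 4
group_size = 128
pack_factor = 32 // weight_bits   # 8 INT4 values per INT32 element

qweight = torch.empty(input_size, output_size_partition // pack_factor,
                      dtype=torch.int32)
qzeros = torch.empty(input_size // group_size,
                     output_size_partition // pack_factor, dtype=torch.int32)
scales = torch.empty(input_size // group_size, output_size_partition,
                     dtype=torch.float16)

print(qweight.shape, qzeros.shape, scales.shape)
# torch.Size([4096, 1376]) torch.Size([32, 1376]) torch.Size([32, 11008])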
@ -8,7 +8,8 @@ from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig
 from vllm.model_executor.models import *  # pylint: disable=wildcard-import
-from vllm.model_executor.weight_utils import initialize_dummy_weights
+from vllm.model_executor.weight_utils import (get_quant_config,
+                                              initialize_dummy_weights)
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
@ -30,6 +31,11 @@ _MODEL_REGISTRY = {
     "RWForCausalLM": FalconForCausalLM,
 }
 
+# FIXME(woosuk): Remove this once all models support quantization.
+_MODEL_CLASSES_SUPPORT_QUANTIZATION = [
+    LlamaForCausalLM,
+]
+
 
 @contextlib.contextmanager
 def _set_default_torch_dtype(dtype: torch.dtype):
@ -52,10 +58,30 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
 
 def get_model(model_config: ModelConfig) -> nn.Module:
     model_class = _get_model_architecture(model_config.hf_config)
+
+    # Get the quantization config.
+    quant_config = None
+    if model_config.quantization is not None:
+        if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
+            raise ValueError(
+                f"Quantization is not supported for {model_class}.")
+        quant_config = get_quant_config(model_config.quantization,
+                                        model_config.model,
+                                        model_config.download_dir)
+        supported_dtypes = quant_config.get_supported_act_dtypes()
+        if model_config.dtype not in supported_dtypes:
+            raise ValueError(
+                f"{model_config.dtype} is not supported for quantization "
+                f"method {model_config.quantization}. Supported dtypes: "
+                f"{supported_dtypes}")
+
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        model = model_class(model_config.hf_config)
+        if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION:
+            model = model_class(model_config.hf_config, quant_config)
+        else:
+            model = model_class(model_config.hf_config)
         if model_config.load_format == "dummy":
             model = model.cuda()
             # NOTE(woosuk): For accurate performance evaluation, we assign
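For context, a minimal end-to-end usage sketch with this commit applied. The model name below is a hypothetical AWQ-quantized LLaMA checkpoint; any local directory or Hugging Face repo that ships AWQ weights plus a quant_config.json (or quantize_config.json) should be resolved by the loader path above in the same way.

from vllm import LLM, SamplingParams

# Hypothetical AWQ checkpoint name; substitute a real one.
llm = LLM(model="some-org/llama-7b-awq", quantization="awq", dtype="half")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)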
@ -36,13 +36,15 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
 from vllm.model_executor.layers.sampler import Sampler
-from vllm.model_executor.weight_utils import (
-    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
-    hf_model_weights_iterator)
+from vllm.model_executor.layers.quantized_linear import ParallelLinear
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+    VocabParallelEmbedding)
+from vllm.model_executor.quantization_utils import QuantizationConfig
+from vllm.model_executor.weight_utils import (
+    load_tensor_parallel_weights, load_padded_tensor_parallel_vocab,
+    hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -55,18 +57,21 @@ class LlamaMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-    ):
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ParallelLinear.column(hidden_size,
                                                  2 * intermediate_size,
                                                  bias=False,
                                                  gather_output=False,
+                                                  perform_initialization=False,
+                                                  quant_config=quant_config)
+        self.down_proj = ParallelLinear.row(intermediate_size,
+                                            hidden_size,
+                                            bias=False,
+                                            input_is_parallel=True,
+                                            perform_initialization=False,
+                                            quant_config=quant_config)
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
@ -87,7 +92,8 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-    ):
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@ -103,20 +109,22 @@ class LlamaAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
 
-        self.qkv_proj = ColumnParallelLinear(
+        self.qkv_proj = ParallelLinear.column(
             hidden_size,
             (self.total_num_heads + 2 * self.total_num_kv_heads) *
             self.head_dim,
             bias=False,
             gather_output=False,
             perform_initialization=False,
+            quant_config=quant_config,
         )
-        self.o_proj = RowParallelLinear(
+        self.o_proj = ParallelLinear.row(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
             perform_initialization=False,
+            quant_config=quant_config,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,
                                            self.head_dim,
@ -144,7 +152,11 @@ class LlamaAttention(nn.Module):
 
 class LlamaDecoderLayer(nn.Module):
 
-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
@ -154,11 +166,13 @@ class LlamaDecoderLayer(nn.Module):
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@ -195,7 +209,11 @@ class LlamaDecoderLayer(nn.Module):
 
 class LlamaModel(nn.Module):
 
-    def __init__(self, config: LlamaConfig):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.config = config
         self.padding_idx = config.pad_token_id
@ -205,7 +223,8 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             vocab_size, config.hidden_size, perform_initialization=False)
         self.layers = nn.ModuleList([
-            LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
+            LlamaDecoderLayer(config, quant_config)
+            for _ in range(config.num_hidden_layers)
         ])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@ -237,16 +256,23 @@ class LlamaModel(nn.Module):
 
 class LlamaForCausalLM(nn.Module):
 
-    def __init__(self, config):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.config = config
-        self.model = LlamaModel(config)
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        # NOTE: The LM head is not quantized.
+        self.lm_head = ParallelLinear.column(config.hidden_size,
+                                             vocab_size,
+                                             bias=False,
+                                             gather_output=False,
+                                             perform_initialization=False,
+                                             quant_config=None)
         self.sampler = Sampler(config.vocab_size)
 
     def forward(
@ -263,16 +289,28 @@ class LlamaForCausalLM(nn.Module):
                                    input_metadata)
         return next_tokens
 
-    _column_parallel_weights = [
-        "qkv_proj.weight", "gate_proj.weight", "up_proj.weight"
-    ]
-    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
+    _column_parallel_layers = []
+    _row_parallel_layers = ["o_proj", "down_proj"]
 
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
                      load_format: str = "auto",
                      revision: Optional[str] = None):
+        if self.quant_config is None:
+            weight_suffixes = ["weight"]
+        else:
+            weight_suffixes = self.quant_config.get_tp_tensor_names()
+
+        column_parallel_weights: List[str] = []
+        for layer in self._column_parallel_layers:
+            for suffix in weight_suffixes:
+                column_parallel_weights.append(f"{layer}.{suffix}")
+        row_parallel_weights: List[str] = []
+        for layer in self._row_parallel_layers:
+            for suffix in weight_suffixes:
+                row_parallel_weights.append(f"{layer}.{suffix}")
+
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@ -293,11 +331,25 @@ class LlamaForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            is_packed = False
+            is_transposed = False
+            if self.quant_config is not None:
+                is_packed = self.quant_config.is_packed(name)
+                is_transposed = self.quant_config.is_transposed(name)
+            if is_transposed:
+                loaded_weight = loaded_weight.T
+
             is_attention_weight = False
             for weight_name, shard_size, offset in attention_weight_specs:
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "qkv_proj")]
+                if is_transposed:
+                    param = param.T
+
+                if is_packed:
+                    shard_size //= self.quant_config.pack_factor
+                    offset //= self.quant_config.pack_factor
+
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@ -316,6 +368,9 @@ class LlamaForCausalLM(nn.Module):
                 if weight_name not in name:
                     continue
                 param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                if is_transposed:
+                    param = param.T
+
                 shard_size = param.shape[0] // 2
                 loaded_weight = loaded_weight[
                     shard_size * tensor_model_parallel_rank:shard_size *
@ -330,6 +385,8 @@ class LlamaForCausalLM(nn.Module):
                 continue
 
             param = state_dict[name]
+            if is_transposed:
+                param = param.T
 
             if "embed_tokens" in name or "lm_head" in name:
                 load_padded_tensor_parallel_vocab(param, loaded_weight,
@ -337,6 +394,6 @@ class LlamaForCausalLM(nn.Module):
                     continue
 
                 load_tensor_parallel_weights(param, loaded_weight, name,
-                                             self._column_parallel_weights,
-                                             self._row_parallel_weights,
+                                             column_parallel_weights,
+                                             row_parallel_weights,
                                              tensor_model_parallel_rank)
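To make the weight-name handling above concrete, here is a small standalone sketch of how the parallel-weight name lists expand when a quantization config is present. The suffix list mirrors AWQConfig.get_tp_tensor_names() introduced later in this commit; the layer names are simply the LLaMA row-parallel projections.

# Without quantization, only the dense "weight" tensors are sharded;
# with AWQ, the packed qweight/qzeros and the scales are sharded instead.
row_parallel_layers = ["o_proj", "down_proj"]

def expand(suffixes):
    return [f"{layer}.{suffix}" for layer in row_parallel_layers
            for suffix in suffixes]

print(expand(["weight"]))
# ['o_proj.weight', 'down_proj.weight']
print(expand(["qweight", "qzeros", "scales"]))
# ['o_proj.qweight', 'o_proj.qzeros', 'o_proj.scales',
#  'down_proj.qweight', 'down_proj.qzeros', 'down_proj.scales']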
@ -4,7 +4,7 @@
 
 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@ -16,13 +16,11 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size,
 )
 from .mappings import (
-    copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
 )
 
-from .random import get_cuda_rng_tracker
 from .utils import (
     divide,
     VocabUtility,
@ -65,59 +63,6 @@ def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
         maybe_copy(attribute)
 
 
-def _initialize_affine_weight_gpu(weight, init_method,
-                                  partition_dim, stride=1):
-    """Initialize affine weight for model parallel on GPU."""
-
-    set_tensor_model_parallel_attributes(tensor=weight,
-                                         is_parallel=True,
-                                         dim=partition_dim,
-                                         stride=stride)
-
-    with get_cuda_rng_tracker().fork():
-        init_method(weight)
-
-
-def _initialize_affine_weight_cpu(weight, output_size, input_size,
-                                  per_partition_size, partition_dim,
-                                  init_method, stride=1,
-                                  return_master_weight=False,
-                                  *, params_dtype=None):
-    """Initialize affine weight for model parallel.
-
-    Build the master weight on all processes and scatter
-    the relevant chunk."""
-
-    set_tensor_model_parallel_attributes(tensor=weight,
-                                         is_parallel=True,
-                                         dim=partition_dim,
-                                         stride=stride)
-
-    if params_dtype is None:
-        params_dtype = torch.get_default_dtype()
-
-    # Initialize master weight
-    master_weight = torch.empty(output_size, input_size,
-                                dtype=torch.float,
-                                requires_grad=False)
-    init_method(master_weight)
-    master_weight = master_weight.to(dtype=params_dtype)
-
-    # Split and copy
-    per_partition_per_stride_size = divide(per_partition_size, stride)
-    weight_list = torch.split(master_weight, per_partition_per_stride_size,
-                              dim=partition_dim)
-    rank = get_tensor_model_parallel_rank()
-    world_size = get_tensor_model_parallel_world_size()
-    my_weight_list = weight_list[rank::world_size]
-
-    with torch.no_grad():
-        torch.cat(my_weight_list, dim=partition_dim, out=weight)
-    if return_master_weight:
-        return master_weight
-    return None
-
-
 class VocabParallelEmbedding(torch.nn.Module):
     """Embedding parallelized in the vocabulary dimension.
 
@ -140,6 +85,9 @@ class VocabParallelEmbedding(torch.nn.Module):
                  use_cpu_initialization: bool=False,
                  perform_initialization: bool=True):
         super(VocabParallelEmbedding, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization
+
         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
@ -162,24 +110,10 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - \
             self.vocab_start_index
 
-        # Allocate weights and initialize.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(
-                self.num_embeddings_per_partition, self.embedding_dim,
-                dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_cpu(
-                    self.weight, self.num_embeddings, self.embedding_dim,
-                    self.num_embeddings_per_partition, 0, init_method,
-                    params_dtype=params_dtype)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.num_embeddings_per_partition, self.embedding_dim,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=0, stride=1)
+        self.weight = Parameter(torch.empty(
+            self.num_embeddings_per_partition, self.embedding_dim,
+            device=torch.cuda.current_device(), dtype=params_dtype))
 
     def forward(self, input_):
         if self.tensor_model_parallel_size > 1:
             # Build the mask.
@ -239,8 +173,11 @@ class ColumnParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 quant_config=None,
                  ):
         super(ColumnParallelLinear, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization
+
         # Keep input parameters
         self.input_size = input_size
@ -250,6 +187,7 @@ class ColumnParallelLinear(torch.nn.Module):
         self.world_size = get_tensor_model_parallel_world_size()
         self.output_size_per_partition = divide(output_size, self.world_size)
         self.skip_bias_add = skip_bias_add
+        self.quant_config = quant_config
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
@ -257,33 +195,13 @@ class ColumnParallelLinear(torch.nn.Module):
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size_per_partition,
-                                                self.input_size,
-                                                dtype=params_dtype))
-            if perform_initialization:
-                self.master_weight = _initialize_affine_weight_cpu(
-                    self.weight, self.output_size, self.input_size,
-                    self.output_size_per_partition, 0, init_method,
-                    stride=stride, return_master_weight=keep_master_weight_for_test)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.output_size_per_partition, self.input_size,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=0, stride=stride)
+        self.create_weights(params_dtype)
 
         if bias:
-            if use_cpu_initialization:
-                self.bias = Parameter(torch.empty(
-                    self.output_size_per_partition, dtype=params_dtype))
-            else:
-                self.bias = Parameter(torch.empty(
-                    self.output_size_per_partition,
-                    device=torch.cuda.current_device(),
-                    dtype=params_dtype))
+            self.bias = Parameter(torch.empty(
+                self.output_size_per_partition,
+                device=torch.cuda.current_device(),
+                dtype=params_dtype))
             set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
             # Always initialize bias to zero.
             with torch.no_grad():
@ -291,6 +209,17 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
 
+    def create_weights(self, dtype: torch.dtype) -> None:
+        self.weight = Parameter(torch.empty(
+            self.output_size_per_partition, self.input_size,
+            device=torch.cuda.current_device(), dtype=dtype))
+
+    def apply_weights(
+        self,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        return F.linear(x, self.weight, bias)
+
     def forward(self, input_):
         """Forward of ColumnParallelLinear
@ -306,7 +235,7 @@ class ColumnParallelLinear(torch.nn.Module):
 
         input_parallel = input_
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight, bias)
+        output_parallel = self.apply_weights(input_parallel, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = gather_from_tensor_model_parallel_region(output_parallel)
@ -361,8 +290,11 @@ class RowParallelLinear(torch.nn.Module):
                  use_cpu_initialization=False,
                  perform_initialization=True,
                  reduce_results=True,
+                 quant_config=None,
                  ):
         super(RowParallelLinear, self).__init__()
+        assert not perform_initialization
+        assert not use_cpu_initialization
+
         # Keep input parameters
         self.input_size = input_size
@ -376,47 +308,32 @@ class RowParallelLinear(torch.nn.Module):
         self.world_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, self.world_size)
         self.skip_bias_add = skip_bias_add
+        self.quant_config = quant_config
+
+        self.create_weights(params_dtype)
+
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
                              "results can lead to incorrect results")
 
-        # Parameters.
-        # Note: torch.nn.functional.linear performs XA^T + b and as a result
-        # we allocate the transpose.
-        # Initialize weight.
-        if use_cpu_initialization:
-            self.weight = Parameter(torch.empty(self.output_size,
-                                                self.input_size_per_partition,
-                                                dtype=params_dtype))
-            if perform_initialization:
-                self.master_weight = _initialize_affine_weight_cpu(
-                    self.weight, self.output_size, self.input_size,
-                    self.input_size_per_partition, 1, init_method,
-                    stride=stride, return_master_weight=keep_master_weight_for_test,
-                    params_dtype=params_dtype)
-        else:
-            self.weight = Parameter(torch.empty(
-                self.output_size, self.input_size_per_partition,
-                device=torch.cuda.current_device(), dtype=params_dtype))
-            if perform_initialization:
-                _initialize_affine_weight_gpu(self.weight, init_method,
-                                              partition_dim=1, stride=stride)
         if bias:
-            if use_cpu_initialization:
-                self.bias = Parameter(torch.empty(self.output_size,
-                                                  dtype=params_dtype))
-            else:
-                self.bias = Parameter(torch.empty(
-                    self.output_size, device=torch.cuda.current_device(),
-                    dtype=params_dtype))
+            self.bias = Parameter(torch.empty(
+                self.output_size, device=torch.cuda.current_device(),
+                dtype=params_dtype))
 
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-        self.weight_t = self.weight.t()
+
+    def create_weights(self, dtype: torch.dtype) -> None:
+        self.weight = Parameter(torch.empty(
+            self.output_size, self.input_size_per_partition,
+            device=torch.cuda.current_device(), dtype=dtype))
+
+    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
+        return F.linear(x, self.weight)
 
     def forward(self, input_):
         """Forward of RowParallelLinear
@ -434,7 +351,7 @@ class RowParallelLinear(torch.nn.Module):
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = self.apply_weights(input_parallel)
         if self.reduce_results and self.world_size > 1:
             output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         else:
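The refactor above turns weight allocation and the matrix multiply into overridable hooks, which is what lets the AWQ subclasses swap in packed INT32 weights without touching the forward or communication path. The toy standalone sketch below illustrates the same pattern with plain torch.nn (no tensor parallelism and no vLLM imports); Int8Linear is a made-up stand-in for a quantized subclass, not part of this commit.

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class BaseLinear(torch.nn.Module):
    # Mirrors the hook structure above: __init__ calls create_weights(),
    # forward() calls apply_weights().
    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.create_weights(torch.float32)

    def create_weights(self, dtype: torch.dtype) -> None:
        self.weight = Parameter(torch.empty(self.output_size, self.input_size,
                                            dtype=dtype))

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.apply_weights(x)


class Int8Linear(BaseLinear):
    # Hypothetical quantized subclass: only the two hooks change, the
    # forward path in the base class stays untouched.
    def create_weights(self, dtype: torch.dtype) -> None:
        self.qweight = Parameter(torch.empty(self.output_size, self.input_size,
                                             dtype=torch.int8),
                                 requires_grad=False)
        self.scale = Parameter(torch.empty(self.output_size, 1, dtype=dtype),
                               requires_grad=False)

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly; a real kernel would fuse this.
        return F.linear(x, self.qweight.to(x.dtype) * self.scale)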
20
vllm/model_executor/quantization_utils/__init__.py
Normal file
@ -0,0 +1,20 @@
from typing import Type

from vllm.model_executor.quantization_utils.awq import AWQConfig
from vllm.model_executor.quantization_utils.base import QuantizationConfig

_QUANTIZATION_REGISTRY = {
    "awq": AWQConfig,
}


def get_quant_class(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in _QUANTIZATION_REGISTRY:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return _QUANTIZATION_REGISTRY[quantization]


__all__ = [
    "QuantizationConfig",
    "get_quant_class",
]
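With this module in place, resolving a method name to its config class is a dictionary lookup. A minimal sketch of how the rest of this commit uses it, assuming the package above is importable:

from vllm.model_executor.quantization_utils import get_quant_class

quant_cls = get_quant_class("awq")       # -> AWQConfig
print(quant_cls.get_config_filenames())  # JSON files to look for next to the weights
# get_quant_class("gptq") would raise ValueError: not registered (yet).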
67
vllm/model_executor/quantization_utils/awq.py
Normal file
@ -0,0 +1,67 @@
from typing import Any, Dict, List

import torch

from vllm.model_executor.quantization_utils.base import QuantizationConfig


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits.")
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (f"AWQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point})")

    @classmethod
    def get_name(cls) -> str:
        return "awq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq  # pylint: disable=line-too-long
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        return cls(weight_bits, group_size, zero_point)

    @classmethod
    def get_packed_tensor_names(cls) -> List[str]:
        return ["qweight", "qzeros"]

    @classmethod
    def get_transposed_tensor_names(cls) -> List[str]:
        return ["qweight", "qzeros", "scales"]

    @classmethod
    def get_tp_tensor_names(cls) -> List[str]:
        return ["qweight", "qzeros", "scales"]
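For illustration, a sketch of how this config gets built from a checkpoint's JSON. The dictionary below is a made-up but typical quant_config.json payload; the key names are the ones from_config() searches for, and extra keys are simply ignored.

from vllm.model_executor.quantization_utils.awq import AWQConfig

# Example payload; real checkpoints ship this as quant_config.json or
# quantize_config.json next to the weights.
config = {"w_bit": 4, "q_group_size": 128, "zero_point": True, "version": "GEMM"}

awq_config = AWQConfig.from_config(config)
print(awq_config)              # AWQConfig(weight_bits=4, group_size=128, zero_point=True)
print(awq_config.pack_factor)  # 8: eight INT4 weights per INT32 element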
65
vllm/model_executor/quantization_utils/base.py
Normal file
@ -0,0 +1,65 @@
from typing import Any, Dict, List

import torch


class QuantizationConfig:

    @classmethod
    def get_name(cls) -> str:
        """Name of the quantization method."""
        raise NotImplementedError

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """List of supported activation dtypes."""
        raise NotImplementedError

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        raise NotImplementedError

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
        """Create a config class from the model's quantization config."""
        raise NotImplementedError

    @staticmethod
    def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
        """Get a value from the model's quantization config."""
        for key in keys:
            if key in config:
                return config[key]
        raise ValueError(f"Cannot find any of {keys} in the model's "
                         "quantization config.")

    @classmethod
    def get_packed_tensor_names(cls) -> List[str]:
        raise NotImplementedError

    @classmethod
    def is_packed(cls, tensor_name: str) -> bool:
        """Returns True if a tensor is packed.

        A tensor is considered packed if each element in the tensor is a
        packed representation of multiple elements in the original tensor.
        For example, an INT32 element in the tensor may represent 8 INT4
        elements in the original tensor.
        """
        return any(tag in tensor_name for tag in cls.get_packed_tensor_names())

    @classmethod
    def get_transposed_tensor_names(cls) -> List[str]:
        raise NotImplementedError

    @classmethod
    def is_transposed(cls, tensor_name: str) -> bool:
        """Returns True if a tensor is transposed relative to nn.Linear.weight.
        """
        return any(tag in tensor_name
                   for tag in cls.get_transposed_tensor_names())

    @classmethod
    def get_tp_tensor_names(cls) -> List[str]:
        raise NotImplementedError
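A short sketch of how these name-based predicates behave for AWQ, which the LLaMA load_weights() path above relies on. It assumes the commit's modules are importable; the parameter names are the ones the AWQ layers register.

from vllm.model_executor.quantization_utils.awq import AWQConfig

for name in ["model.layers.0.mlp.down_proj.qweight",
             "model.layers.0.mlp.down_proj.scales",
             "model.layers.0.input_layernorm.weight"]:
    print(name, AWQConfig.is_packed(name), AWQConfig.is_transposed(name))
# ...down_proj.qweight  True  True   -> packed INT32, stored transposed
# ...down_proj.scales   False True   -> fp16, stored transposed
# ...input_layernorm.weight False False -> loaded as usual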
@ -4,7 +4,7 @@ import glob
 import json
 import os
 from collections import defaultdict
-from typing import Iterator, List, Optional, Tuple, Any
+from typing import Any, Iterator, List, Optional, Tuple
 
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file, save_file, safe_open
@ -13,6 +13,8 @@ import torch
 from tqdm.auto import tqdm
 
 from vllm.logger import init_logger
+from vllm.model_executor.quantization_utils import get_quant_class
+from vllm.model_executor.quantization_utils.base import QuantizationConfig
 
 logger = init_logger(__name__)
 
@ -44,7 +46,7 @@ def _shared_pointers(tensors):
 def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
-):
+) -> None:
     loaded = torch.load(pt_filename, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
@ -78,16 +80,55 @@ def convert_bin_to_safetensor_file(
         raise RuntimeError(f"The output tensors do not match for key {k}")
 
 
+# TODO(woosuk): Move this to other place.
+def get_quant_config(
+    quantization: str,
+    model_name_or_path: str,
+    cache_dir: Optional[str] = None,
+) -> QuantizationConfig:
+    is_local = os.path.isdir(model_name_or_path)
+    if not is_local:
+        # Download the config files.
+        with get_lock(model_name_or_path, cache_dir):
+            hf_folder = snapshot_download(model_name_or_path,
+                                          allow_patterns="*.json",
+                                          cache_dir=cache_dir,
+                                          tqdm_class=Disabledtqdm)
+    else:
+        hf_folder = model_name_or_path
+    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
+
+    quant_cls = get_quant_class(quantization)
+    quant_config_files = [
+        f for f in config_files if any(
+            f.endswith(x) for x in quant_cls.get_config_filenames())
+    ]
+    if len(quant_config_files) == 0:
+        raise ValueError(f"Cannot find the config file for {quantization}")
+    if len(quant_config_files) > 1:
+        raise ValueError(f"Found multiple config files for {quantization}: "
+                         f"{quant_config_files}")
+
+    quant_config_file = quant_config_files[0]
+    with open(quant_config_file, "r") as f:
+        config = json.load(f)
+    return quant_cls.from_config(config)
+
+
 def prepare_hf_model_weights(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
     revision: Optional[str] = None,
-):
+) -> Tuple[str, List[str], bool]:
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
+    if use_safetensors:
+        allow_patterns = ["*.safetensors"]
+    else:
+        # Some quantized models use .pt files for storing the weights.
+        allow_patterns = ["*.bin", "*.pt"]
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@ -99,7 +140,9 @@ def prepare_hf_model_weights(
                               revision=revision)
     else:
         hf_folder = model_name_or_path
-    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
     if not use_safetensors:
         hf_weights_files = [
            x for x in hf_weights_files if not x.endswith("training_args.bin")
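To make the config-file discovery concrete, a small standalone sketch of the filtering that get_quant_config() performs; the file names below are hypothetical directory contents, and the suffix list is AWQConfig.get_config_filenames().

# Hypothetical listing of *.json files found next to an AWQ checkpoint.
config_files = ["/models/llama-awq/config.json",
                "/models/llama-awq/quant_config.json",
                "/models/llama-awq/generation_config.json"]
awq_filenames = ["quant_config.json", "quantize_config.json"]

matches = [f for f in config_files
           if any(f.endswith(x) for x in awq_filenames)]
print(matches)  # ['/models/llama-awq/quant_config.json'] -> parsed by AWQConfig.from_config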