Add GPTQ Marlin 2:4 sparse structured support (#4790)

Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>

Parent: 9216b9cc38
Commit: 6979ade384
CMakeLists.txt

@@ -176,7 +176,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   list(APPEND VLLM_EXT_SRC
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
-    "csrc/quantization/marlin/marlin_cuda_kernel.cu"
+    "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
+    "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/custom_all_reduce.cu")
csrc/ops.h (+11)

@@ -125,6 +125,17 @@ torch::Tensor marlin_gemm(
   int64_t size_n,
   int64_t size_k);

+torch::Tensor gptq_marlin_24_gemm(
+  torch::Tensor &a,
+  torch::Tensor &b_q_weight,
+  torch::Tensor &b_meta,
+  torch::Tensor &b_scales,
+  torch::Tensor &workspace,
+  int64_t num_bits,
+  int64_t size_m,
+  int64_t size_n,
+  int64_t size_k);
+
 torch::Tensor gptq_marlin_gemm(
   torch::Tensor &a,
   torch::Tensor &b_q_weight,
@@ -66,7 +66,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("aqlm_gemm", &aqlm_gemm, "Quantized GEMM for AQLM");
   ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM");
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
-  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
+  ops.def("marlin_gemm", &marlin_gemm, "Marlin (Dense) Optimized Quantized GEMM for GPTQ");
+  ops.def("gptq_marlin_24_gemm", &gptq_marlin_24_gemm, "Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ");
   ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ");
   ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
csrc/quantization/marlin/sparse/LICENSE (new file, 203 lines)

@@ -0,0 +1,203 @@
Contains code from https://github.com/IST-DASLab/Sparse-Marlin/

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

(Standard Apache License 2.0 text: Sections 1 through 9 of the terms and
conditions, followed by the appendix describing how to apply the license to a
work.)
csrc/quantization/marlin/sparse/common/base.h (new file, 49 lines)

@@ -0,0 +1,49 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

namespace marlin_24 {

constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }

// Instances of `Vec` are used to organize groups of >>registers<<, as needed
// for instance as inputs to tensor core operations. Consequently, all
// corresponding index accesses must be compile-time constants, which is why we
// extensively use `#pragma unroll` throughout the kernel code to guarantee
// this.
template <typename T, int n> struct Vec {
  T elems[n];
  __device__ T &operator[](int i) { return elems[i]; }
};

template <int M_, int N_, int K_> struct ShapeBase {
  static constexpr int M = M_, N = N_, K = K_;
};

using I4 = Vec<int, 4>;

// Matrix fragments for tensor core instructions; their precise layout is
// documented here:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragM = Vec<uint, 1>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales

} // namespace marlin_24
csrc/quantization/marlin/sparse/common/mem.h (new file, 132 lines)

@@ -0,0 +1,132 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "base.h"

namespace marlin_24 {
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
__device__ inline void cp_async4_pred_zfill(void *smem_ptr,
                                            const void *glob_ptr,
                                            bool pred = true,
                                            const bool zfill = false) {
  const int BYTES = 16;
  int src_in_bytes = (zfill ? 0 : BYTES);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("{\n"
               "   .reg .pred p;\n"
               "   setp.ne.b32 p, %0, 0;\n"
               "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
               "}\n" ::"r"((int)pred),
               "r"(smem), "l"(glob_ptr), "n"(BYTES), "r"(src_in_bytes));
}

__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
                                      bool pred = true) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("{\n"
               "   .reg .pred p;\n"
               "   setp.ne.b32 p, %0, 0;\n"
               "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
               "}\n" ::"r"((int)pred),
               "r"(smem), "l"(glob_ptr), "n"(BYTES));
}

// Asynchronous global->shared copy
__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
  const int BYTES = 16;
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("{\n"
               "   cp.async.cg.shared.global [%0], [%1], %2;\n"
               "}\n" ::"r"(smem),
               "l"(glob_ptr), "n"(BYTES));
}

// Async copy fence.
__device__ inline void cp_async_fence() {
  asm volatile("cp.async.commit_group;\n" ::);
}

// Wait until at most `n` async copy stages are still pending.
template <int n> __device__ inline void cp_async_wait() {
  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) {
  uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
               : "r"(smem));
}

__device__ inline void ldsm4_m(FragM &frag_m, const void *smem_ptr) {
  uint32_t *a = reinterpret_cast<uint32_t *>(&frag_m);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
               : "=r"(a[0]), "=r"(a[1])
               : "r"(smem));
}

// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
__device__ inline void ldsm4_t(FragA &frag_a, const void *smem_ptr) {
  uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
  asm volatile(
      "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0,%1,%2,%3}, [%4];\n"
      : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
      : "r"(smem));
}

// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int *lock, int count) {
  if (threadIdx.x == 0) {
    int state = -1;
    do
      // Guarantee that subsequent writes by this threadblock will be visible
      // globally.
      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
                   : "=r"(state)
                   : "l"(lock));
    while (state != count);
  }
  __syncthreads();
}

// Release barrier and increment visitation count.
__device__ inline void barrier_release(int *lock, bool reset = false) {
  __syncthreads();
  if (threadIdx.x == 0) {
    if (reset) {
      lock[0] = 0;
      return;
    }
    int val = 1;
    // Make sure that all writes since acquiring this barrier are visible
    // globally, while releasing the barrier.
    asm volatile("fence.acq_rel.gpu;\n");
    asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
                 :
                 : "l"(lock), "r"(val));
  }
}
} // namespace marlin_24
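The predication comment above is easiest to see with a concrete batch size. The following plain-Python sketch is illustrative only and not part of this commit; the real kernel predicates individual 16-byte copy chunks, and batch_size = 10 is a hypothetical value.

# Which rows of a 16-row operand-A tile get copied versus zero-filled/skipped
# when the runtime batch size is not a multiple of 16 (hypothetical m = 10).
batch_size = 10
for row in range(16):
    pred = row < batch_size
    action = "async-copy 16 bytes" if pred else "zero-fill (zfill=True) or skip"
    print(f"row {row:2d}: {action}")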
csrc/quantization/marlin/sparse/common/mma.h (new file, 175 lines)

@@ -0,0 +1,175 @@
/*
 * Copyright (C) 2024 Roberto Lopez Castro (roberto.lopez.castro@udc.es). All
 * Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "base.h"

namespace marlin_24 {

// m16n8k32 sparse tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
__device__ inline void mma_sp(const FragB &a_frag0, const FragB &a_frag1,
                              const FragA &frag_b, FragC &frag_c, FragM &frag_m,
                              const int psel) {
  const uint32_t *a0 = reinterpret_cast<const uint32_t *>(&a_frag0);
  const uint32_t *a1 = reinterpret_cast<const uint32_t *>(&a_frag1);
  const uint32_t *b = reinterpret_cast<const uint32_t *>(&frag_b);
  const uint32_t *e = reinterpret_cast<const uint32_t *>(&frag_m);
  float *c = reinterpret_cast<float *>(&frag_c);
  if (psel == 0) {
    asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x0;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  } else {
    asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[0]),
                   "r"(b[2]), "r"(b[4]), "r"(b[6]), "f"(c[0]), "f"(c[1]),
                   "f"(c[2]), "f"(c[3]), "r"(e[0]));
    asm volatile("mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 "
                 "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9, %10,%11}, "
                 "{%12,%13,%14,%15}, %16, 0x1;\n"
                 : "=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7])
                 : "r"(a0[0]), "r"(a1[0]), "r"(a0[1]), "r"(a1[1]), "r"(b[1]),
                   "r"(b[3]), "r"(b[5]), "r"(b[7]), "f"(c[4]), "f"(c[5]),
                   "f"(c[6]), "f"(c[7]), "r"(e[0]));
  }
}

// Lookup-table based 3-input logical operation; explicitly used for
// dequantization as the compiler does not seem to automatically recognize it in
// all cases.
template <int lut> __device__ inline int lop3(int a, int b, int c) {
  int res;
  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
               : "=r"(res)
               : "r"(a), "r"(b), "r"(c), "n"(lut));
  return res;
}

__device__ __forceinline__ uint2 to_half4(float c0, float c1, float c2,
                                          float c3) {
  uint2 r;
  asm("{\n\t"
      ".reg .f16 a, b, c, d; \n\t"
      "cvt.rn.f16.f32 a, %2; \n\t"
      "cvt.rn.f16.f32 b, %3; \n\t"
      "cvt.rn.f16.f32 c, %4; \n\t"
      "cvt.rn.f16.f32 d, %5; \n\t"
      "mov.b32 %0, {a, b}; \n\t"
      "mov.b32 %1, {c, d}; \n\t"
      "}"
      : "=r"(r.x), "=r"(r.y)
      : "f"(c0), "f"(c1), "f"(c2), "f"(c3));
  return r;
}

// Constructs destination register by taking bytes from 2 sources (based on
// mask)
template <int start_byte, int mask>
__device__ inline uint32_t prmt(uint32_t a) {
  uint32_t res;
  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
               : "=r"(res)
               : "r"(a), "n"(start_byte), "n"(mask));
  return res;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_4bit(int q) {
  const int LO = 0x000f000f;
  const int HI = 0x00f000f0;
  const int EX = 0x64006400;
  // Guarantee that the `(a & b) | c` operations are LOP3s.
  int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
  int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
  // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point
  // directly into `SUB` and `ADD`.
  const int SUB = 0x64086408;
  const int MUL = 0x2c002c00;
  const int ADD = 0xd480d480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                      *reinterpret_cast<const half2 *>(&SUB));
  frag_b[1] = __hfma2(*reinterpret_cast<half2 *>(&hi),
                      *reinterpret_cast<const half2 *>(&MUL),
                      *reinterpret_cast<const half2 *>(&ADD));
  return frag_b;
}

// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
__device__ inline FragB dequant_8bit(int q) {
  static constexpr uint32_t mask_for_elt_01 = 0x5250;
  static constexpr uint32_t mask_for_elt_23 = 0x5351;
  static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

  uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
  uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);

  static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480;

  FragB frag_b;
  frag_b[0] = __hsub2(*reinterpret_cast<half2 *>(&lo),
                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
  frag_b[1] = __hsub2(*reinterpret_cast<half2 *>(&hi),
                      *reinterpret_cast<const half2 *>(&I8s_TO_F16s_MAGIC_NUM));
  return frag_b;
}

// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) {
  half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]);
  frag_b[0] = __hmul2(frag_b[0], s);
  frag_b[1] = __hmul2(frag_b[1], s);
}

__device__ inline void scale_floats(float *c0, float *c1, float *c2, float *c3,
                                    FragS &s0, float *c4, float *c5, float *c6,
                                    float *c7, FragS &s1) {
  *c0 = __fmul_rn(*c0, __half2float(s0[0].x));
  *c1 = __fmul_rn(*c1, __half2float(s0[0].y));
  *c2 = __fmul_rn(*c2, __half2float(s0[1].x));
  *c3 = __fmul_rn(*c3, __half2float(s0[1].y));

  *c4 = __fmul_rn(*c4, __half2float(s1[0].x));
  *c5 = __fmul_rn(*c5, __half2float(s1[0].y));
  *c6 = __fmul_rn(*c6, __half2float(s1[1].x));
  *c7 = __fmul_rn(*c7, __half2float(s1[1].y));
}

} // namespace marlin_24
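The dequant_4bit constants follow the FasterTransformer fp16 trick referenced in the comment above. As a sanity-check sketch (not part of this commit, it only assumes numpy), the magic values can be decoded on the host to see why the lop3 / __hsub2 / __hfma2 sequence maps a 4-bit nibble q to the signed value q - 8:

import numpy as np

def as_fp16(bits: int) -> float:
    # Reinterpret a 16-bit pattern as an IEEE fp16 value.
    return float(np.array([bits], dtype=np.uint16).view(np.float16)[0])

q = 5                                            # any low nibble in 0..15
lo = as_fp16(0x6400 | q)                         # (q & LO) | EX  ->  1024 + q
print(lo - as_fp16(0x6408))                      # SUB encodes 1024 + 8  ->  q - 8 = -3.0

q_hi = 5                                         # any high nibble in 0..15
hi = as_fp16(0x6400 | (q_hi << 4))               # (q & HI) | EX  ->  1024 + 16*q_hi
print(hi * as_fp16(0x2C00) + as_fp16(0xD480))    # MUL = 1/16, ADD = -72  ->  q_hi - 8 = -3.0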
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu (new file, 1110 lines)

(File diff suppressed because it is too large.)
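For readers unfamiliar with the "2:4 sparse structured" format in the commit title: the suppressed kernel consumes weights compressed so that every group of four values along the reduction (K) dimension keeps at most two non-zeros, plus 2-bit metadata recording which positions survived; that metadata is what the B_meta tensor introduced below carries and what the mma.sp instruction above consumes. A small illustrative sketch of the compression, using made-up numbers:

import numpy as np

dense = np.array([0.0, 1.5, 0.0, -2.0,   # group 0: keep positions 1 and 3
                  0.7, 0.0, 0.3, 0.0])   # group 1: keep positions 0 and 2

kept_vals, kept_idx = [], []
for group in dense.reshape(-1, 4):
    nz = np.flatnonzero(group)[:2]       # the (at most) 2 surviving positions
    kept_vals.extend(group[nz])
    kept_idx.extend(nz)

print(kept_vals)  # [1.5, -2.0, 0.7, 0.3] -> compressed values (half the storage)
print(kept_idx)   # [1, 3, 0, 2]          -> 2-bit indices packed into the metadata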
tests/models/test_gptq_marlin_24.py (new file, 81 lines)

@@ -0,0 +1,81 @@
"""Compare the outputs of a GPTQ model to a Marlin_24 model.

Note: GPTQ and Marlin_24 do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.

Run `pytest tests/models/test_marlin_24.py`.
"""
from dataclasses import dataclass

import pytest
import torch

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (capability <
                        QUANTIZATION_METHODS["marlin"].get_min_capability())


@dataclass
class ModelPair:
    model_marlin: str
    model_gptq: str


model_pairs = [
    # 4-bit, group_size == 128
    ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-g128",
              model_gptq="alexm-nm/tinyllama-24-gptq-4bit-g128"),
    # 4-bit, group_size == channelwise
    ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-4bit-channelwise",
              model_gptq="alexm-nm/tinyllama-24-gptq-4bit-channelwise"),

    # 8-bit, group_size == 128
    ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-g128",
              model_gptq="alexm-nm/tinyllama-24-gptq-8bit-g128"),
    # 8-bit, group_size == channelwise
    ModelPair(model_marlin="alexm-nm/tinyllama-24-marlin24-8bit-channelwise",
              model_gptq="alexm-nm/tinyllama-24-gptq-8bit-channelwise"),
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
                    reason="Marlin24 is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    marlin_24_model = vllm_runner(model_pair.model_marlin,
                                  dtype=dtype,
                                  quantization="gptq_marlin_24")
    marlin_24_outputs = marlin_24_model.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)
    del marlin_24_model

    gptq_model = vllm_runner(model_pair.model_gptq,
                             dtype=dtype,
                             quantization="gptq")
    gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)
    del gptq_model

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=marlin_24_outputs,
        name_0="gptq",
        name_1="marlin_24",
    )
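The test itself is invoked with `pytest tests/models/test_gptq_marlin_24.py`. Outside the test harness, the same 2:4-sparse checkpoints can be loaded through the public API. A minimal sketch (it assumes this vLLM revision is installed, a CUDA GPU with compute capability >= 8.0, and access to the Hugging Face checkpoint named in the test above):

from vllm import LLM, SamplingParams

# quantization= can also be omitted: the checkpoint_format == "marlin_24"
# override introduced in this commit selects gptq_marlin_24 automatically.
llm = LLM(model="alexm-nm/tinyllama-24-marlin24-4bit-g128",
          quantization="gptq_marlin_24",
          dtype="half")
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=8))
print(outputs[0].outputs[0].text)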
vllm/_custom_ops.py

@@ -153,6 +153,16 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                 size_n, size_k)


+# marlin_24
+def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                        b_meta: torch.Tensor, b_scales: torch.Tensor,
+                        workspace: torch.Tensor, num_bits: int, size_m: int,
+                        size_n: int, size_k: int) -> torch.Tensor:
+    return vllm_ops.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
+                                        workspace, num_bits, size_m, size_n,
+                                        size_k)
+
+
 # aqlm
 def aqlm_gemm(input: torch.Tensor, codes: torch.Tensor,
               codebooks: torch.Tensor, scales: torch.Tensor,
vllm/config.py

@@ -7,14 +7,11 @@ import torch
 from transformers import PretrainedConfig

 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
-                                                     get_quantization_config)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.model_executor.models import ModelRegistry
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron

-GPTQMarlinConfig = get_quantization_config("gptq_marlin")
-
 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup

@@ -155,37 +152,15 @@ class ModelConfig:
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is not None:
             quant_method = quant_cfg.get("quant_method", "").lower()
-            # compat: autogptq >=0.8.0 use checkpoint_format: str
-            # compat: autogptq <=0.7.1 is_marlin_format: bool
-            is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin"
-                                or quant_cfg.get("is_marlin_format", False))
-
-            # Check which LinearMethod the GPTQ model should use.
-            if quant_method == "gptq":
-                # If serialized in Marlin format, use MarlinLinearMethod.
-                # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod.
-                if is_format_marlin:
-                    logger.info("The model is serialized in Marlin format. "
-                                "Using Marlin kernel.")
-                    quant_method = "marlin"
-                    if self.quantization == "gptq":
-                        self.quantization = quant_method
-
-                # If convertible to Marlin format, use GPTQMarlinLinearMethod
-                # unless the user explicitly specified GPTQLinearMethod.
-                elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg):
-                    if self.quantization == "gptq":
-                        logger.warning(
-                            "The model is convertible to Marlin format, but "
-                            "you specified quantization=gptq. Use "
-                            "quantization=marlin for faster inference.")
-                    else:
-                        logger.info(
-                            "The model is convertible to Marlin format. "
-                            "Using Marlin kernel.")
-                        quant_method = "gptq_marlin"
-                        if self.quantization == "marlin":
-                            self.quantization = quant_method
+
+            # Detect which checkpoint is it
+            for name, method in QUANTIZATION_METHODS.items():
+                quantization_override = method.override_quantization_method(
+                    quant_cfg, self.quantization)
+                if quantization_override:
+                    quant_method = quantization_override
+                    self.quantization = quantization_override
+                    break

             # Verify quantization configurations.
             if self.quantization is None:
@@ -207,7 +182,8 @@ class ModelConfig:
             raise ValueError(
                 f"{self.quantization} quantization is currently not "
                 f"supported in ROCm.")
-        if (self.quantization not in ["marlin", "gptq_marlin"]):
+        if (self.quantization
+                not in ["marlin", "gptq_marlin_24", "gptq_marlin"]):
             logger.warning(
                 "%s quantization is not fully "
                 "optimized yet. The speed can be slower than "
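The detection loop above relies on the insertion order of QUANTIZATION_METHODS (see the dict in the next hunk) so that more specific checkpoint formats are probed first. A hedged sketch of how a 2:4-sparse GPTQ checkpoint is resolved; the quantization_config dict below is a hypothetical example of what such a checkpoint ships, and the snippet assumes this vLLM revision is installed:

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

quant_cfg = {                       # hypothetical quantize_config.json contents
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "sym": True,
    "checkpoint_format": "marlin_24",
}
user_quant = None                   # user did not pass --quantization

for name, method in QUANTIZATION_METHODS.items():
    override = method.override_quantization_method(quant_cfg, user_quant)
    if override is not None:
        print(name, "->", override)
        break
# Expected: "gptq_marlin_24 -> gptq_marlin_24", because GPTQMarlin24Config's
# override keys off checkpoint_format == "marlin_24" and is probed before
# gptq_marlin and gptq in the dict order.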
vllm/model_executor/layers/quantization/__init__.py

@@ -10,18 +10,23 @@ from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQMarlinConfig)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQMarlin24Config)
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
     "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
     "fp8": Fp8Config,
+    # The order of gptq methods is important for config.py iteration over
+    # override_quantization_method(..)
+    "marlin": MarlinConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "gptq_marlin": GPTQMarlinConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
-    "gptq_marlin": GPTQMarlinConfig,
-    "marlin": MarlinConfig,
-    "deepspeedfp": DeepSpeedFPConfig
 }
vllm/model_executor/layers/quantization/base_config.py

@@ -66,6 +66,17 @@ class QuantizationConfig(ABC):
         """Create a config class from the model's quantization config."""
         raise NotImplementedError

+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg,
+                                     user_quant) -> Optional[str]:
+        """
+        Detects if this quantization method can support a given checkpoint
+        format by overriding the user specified quantization method --
+        this method should only be overwritten by subclasses in exceptional
+        circumstances
+        """
+        return None
+
     @staticmethod
     def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
         """Get a value from the model's quantization config."""
vllm/model_executor/layers/quantization/gptq_marlin.py

@@ -6,11 +6,14 @@ import torch
 from torch.nn.parameter import Parameter

 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)

+logger = init_logger(__name__)
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128

@@ -117,6 +120,26 @@ class GPTQMarlinConfig(QuantizationConfig):
         is_sym = cls.get_from_keys(config, ["sym"])
         return cls(weight_bits, group_size, desc_act, is_sym)

+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg,
+                                     user_quant) -> Optional[str]:
+        can_convert = cls.is_marlin_compatible(hf_quant_cfg)
+
+        is_valid_user_quant = (user_quant is None or user_quant == "marlin")
+
+        if can_convert and is_valid_user_quant:
+            msg = ("The model is convertible to {} during runtime."
+                   " Using {} kernel.".format(cls.get_name(), cls.get_name()))
+            logger.info(msg)
+            return cls.get_name()
+
+        if can_convert and user_quant == "gptq":
+            logger.info("Detected that the model can run with gptq_marlin"
+                        ", however you specified quantization=gptq explicitly,"
+                        " so forcing gptq. Use quantization=gptq_marlin for"
+                        " faster inference")
+            return None
+
     def get_quant_method(
             self,
             layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
vllm/model_executor/layers/quantization/gptq_marlin_24.py (new file, 280 lines)

@@ -0,0 +1,280 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class GPTQMarlin24Config(QuantizationConfig):
    """Config class for Marlin24.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size

        if self.weight_bits != 4 and self.weight_bits != 8:
            raise ValueError("weight_bits must be 4 or 8. Got = {}".format(
                self.weight_bits))

        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) "
                "is supported for Marlin24, but got group_size of "
                f"{self.group_size}")

        # 4 Bits packed into 32 bit datatype.
        self.pack_factor = 32 // self.weight_bits

        # Tile size used by marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 128

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return "Marlin24Config(weight_bits={}, group_size={})".format(
            self.weight_bits, self.group_size)

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin_24"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        is_marlin_24_format = (
            hf_quant_cfg.get("checkpoint_format") == "marlin_24")

        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
                               or user_quant == "gptq_marlin_24")

        if is_marlin_24_format and is_valid_user_quant:
            msg = ("The model is serialized in {} format. "
                   "Using {} kernel.".format(cls.get_name(), cls.get_name()))
            logger.info(msg)
            return cls.get_name()

        return None

    def get_quant_method(
            self,
            layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
        if isinstance(layer, LinearBase):
            return GPTQMarlin24LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class GPTQMarlin24LinearMethod(LinearMethodBase):
    """Linear method for Marlin24.

    Args:
        quant_config: The Marlin24 quantization config.
    """

    def __init__(self, quant_config: GPTQMarlin24Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"min_n_threads = {self.quant_config.min_n_threads}.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f"pack_factor = {self.quant_config.pack_factor}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"min_k_threads = {self.quant_config.min_k_threads}.")
        if (self.quant_config.group_size != -1 and
                input_size_per_partition % self.quant_config.group_size != 0):
            raise ValueError(f"Weight input_size_per_partition = "
                             f"{input_size_per_partition} is not divisible by "
                             f"group_size = {self.quant_config.group_size}.")

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4Bit weights packed into Int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size // 2,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Meta
        meta = Parameter(
            torch.empty(
                input_size_per_partition // 8 // 2 // 2,
                output_size_per_partition * 2,
                device="cuda",
                dtype=torch.int16,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            meta,
            {
                "input_dim": 0,
                "packed_dim": 1,
                "pack_factor": 1,
                "output_dim": 1,
                "marlin_tile_size": 2,
            },
        )

        # Determine if channelwise or not
        input_groups = (1 if self.quant_config.group_size == -1 else
                        input_size_per_partition //
                        self.quant_config.group_size)

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (Used for internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        layer.register_parameter("B_24", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("B_meta", meta)
        set_weight_attrs(meta, extra_weight_attrs)
        layer.register_parameter("s", scales)
        set_weight_attrs(scales, extra_weight_attrs)
        layer.register_parameter("workspace", workspace)
        set_weight_attrs(workspace, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = layer.B_24
        meta = layer.B_meta
        scales = layer.s
        workspace = layer.workspace

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
                                            workspace,
                                            self.quant_config.weight_bits,
                                            size_m, size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
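The parameter shapes allocated in create_weights follow directly from the constants set in GPTQMarlin24Config. A worked sketch for a hypothetical 4-bit, group_size=128 shard with K = N = 4096 (plain arithmetic, no GPU needed; the layer size is invented for illustration):

# Hypothetical shard: input_size_per_partition = output_size_per_partition = 4096.
weight_bits, group_size = 4, 128
tile_size, pack_factor = 16, 32 // weight_bits          # 16 and 8
min_n_threads, max_parallel = 128, 16
k = n = 4096

qweight_shape = (k // tile_size // 2,
                 n * tile_size // pack_factor)           # (128, 8192)  int32
meta_shape = (k // 8 // 2 // 2, n * 2)                   # (128, 8192)  int16 2:4 metadata
scales_shape = (k // group_size, n)                      # (32, 4096)   fp16 group scales
workspace_numel = (n // min_n_threads) * max_parallel    # 512          zero-initialized locks

print(qweight_shape, meta_shape, scales_shape, workspace_numel)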
vllm/model_executor/layers/quantization/marlin.py

@@ -4,11 +4,14 @@ import torch
 from torch.nn.parameter import Parameter

 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs

+logger = init_logger(__name__)
+

 class MarlinConfig(QuantizationConfig):
     """Config class for Marlin.

@@ -72,6 +75,25 @@ class MarlinConfig(QuantizationConfig):
         group_size = cls.get_from_keys(config, ["group_size"])
         return cls(group_size)

+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg,
+                                     user_quant) -> Optional[str]:
+        # compat: autogptq >=0.8.0 use checkpoint_format: str
+        # compat: autogptq <=0.7.1 is_marlin_format: bool
+        is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
+                            or hf_quant_cfg.get("is_marlin_format", False))
+
+        is_valid_user_quant = (user_quant is None or user_quant == "gptq"
+                               or user_quant == "marlin")
+
+        if is_marlin_format and is_valid_user_quant:
+            msg = ("The model is serialized in {} format. Using {} kernel.".
+                   format(cls.get_name(), cls.get_name()))
+            logger.info(msg)
+            return cls.get_name()
+
+        return None
+
     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
         if isinstance(layer, LinearBase):
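The three overrides (marlin, gptq_marlin_24, gptq_marlin) key off different fields of a checkpoint's quantize_config.json. The dicts below are hypothetical examples, with invented values, of configs each override would claim under the rules shown in this diff:

# Hypothetical quantize_config.json contents, expressed as Python dicts.
dense_marlin = {"quant_method": "gptq", "bits": 4, "group_size": 128,
                "checkpoint_format": "marlin"}    # claimed by MarlinConfig (autogptq >= 0.8.0)
legacy_marlin = {"quant_method": "gptq", "bits": 4, "group_size": 128,
                 "is_marlin_format": True}        # claimed by MarlinConfig (autogptq <= 0.7.1)
sparse_24 = {"quant_method": "gptq", "bits": 4, "group_size": 128,
             "checkpoint_format": "marlin_24"}    # claimed by GPTQMarlin24Config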