Integrate Marlin Kernels for Int4 GPTQ inference (#2497)
Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>
parent 90fbf12540
commit c0c2335ce0
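As a quick orientation, here is a minimal sketch of how the new path can be exercised end to end. It assumes a GPTQ checkpoint already serialized in Marlin format (the nm-testing/zephyr-beta-7b-marlin-g128 model name is taken from the test added below); the LLM entry point and its quantization/dtype arguments are the existing vLLM API, not something added in this commit.

# Sketch only: run greedy decoding with a Marlin-serialized GPTQ checkpoint.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/zephyr-beta-7b-marlin-g128",
          quantization="marlin",  # can also be auto-detected via is_marlin_format (see config.py below)
          dtype="half")           # the Marlin kernels require float16 activations
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)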
@@ -84,6 +84,15 @@ torch::Tensor awq_dequantize(
   int split_k_iters,
   int thx,
   int thy);
 
+torch::Tensor marlin_gemm(
+  torch::Tensor& a,
+  torch::Tensor& b_q_weight,
+  torch::Tensor& b_scales,
+  torch::Tensor& workspace,
+  int64_t size_m,
+  int64_t size_n,
+  int64_t size_k);
+
 #endif
 
 void squeezellm_gemm(
@@ -52,11 +52,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     &rotary_embedding,
     "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
 
   // Quantization ops
 #ifndef USE_ROCM
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
 #endif
 
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
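For reference, the Python-side call that this binding exposes is sketched below. The signature mirrors the C++ declaration above and matches how the new MarlinLinearMethod invokes it later in this commit; the tensor contents are placeholders rather than valid Marlin-packed weights, and the group-128 scales and workspace sizing are assumptions lifted from marlin.py below, so this only illustrates shapes and dtypes.

import torch
from vllm._C import ops

size_m, size_k, size_n = 8, 4096, 4096  # batch, in_features, out_features (hypothetical layer)
tile, pack_factor = 16, 8               # Marlin tile size; eight 4-bit weights per int32

a = torch.empty(size_m, size_k, dtype=torch.float16, device="cuda")
b_q_weight = torch.empty(size_k // tile, size_n * tile // pack_factor,
                         dtype=torch.int32, device="cuda")  # placeholder packing
b_scales = torch.empty(size_k // 128, size_n, dtype=torch.float16, device="cuda")
workspace = torch.zeros(size_n // 64 * 16, dtype=torch.int, device="cuda")

# Returns a (size_m, size_n) float16 tensor.
out = ops.marlin_gemm(a, b_q_weight, b_scales, workspace, size_m, size_n, size_k)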
209  csrc/quantization/marlin/LICENSE  (new file)
@@ -0,0 +1,209 @@
Contains code from https://github.com/IST-DASLab/marlin

[Standard Apache License, Version 2.0 (January 2004, http://www.apache.org/licenses/) text: Sections 1-9, END OF TERMS AND CONDITIONS, and the "How to apply the Apache License to your work" appendix.]

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.
1145  csrc/quantization/marlin/marlin_cuda_kernel.cu  (new file)
File diff suppressed because it is too large.
@@ -15,6 +15,7 @@ types-setuptools
 pytest
 pytest-forked
 pytest-asyncio
+pytest-rerunfailures
 httpx
 einops  # required for MPT
 openai
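The new dev dependency provides the rerun-on-failure marker that the Marlin test relies on. A minimal sketch of the pattern it enables (the test below uses exactly this form):

import pytest

@pytest.mark.flaky(reruns=2)  # provided by pytest-rerunfailures: retry up to 2 more times on failure
def test_slightly_nondeterministic_kernel():
    ...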
2  setup.py
@@ -342,6 +342,8 @@ vllm_extension_sources = [
 
 if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
+    vllm_extension_sources.append(
+        "csrc/quantization/marlin/marlin_cuda_kernel.cu")
     vllm_extension_sources.append("csrc/custom_all_reduce.cu")
 
 # Add MoE kernels.
@@ -199,6 +199,24 @@ class VllmRunner:
             outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs
 
+    def generate_w_logprobs(
+        self,
+        prompts: List[str],
+        sampling_params: SamplingParams,
+    ) -> List[Tuple[List[int], str]]:
+        assert sampling_params.logprobs is not None
+
+        req_outputs = self.model.generate(prompts,
+                                          sampling_params=sampling_params)
+        outputs = []
+        for req_output in req_outputs:
+            for sample in req_output.outputs:
+                output_str = sample.text
+                output_ids = sample.token_ids
+                output_logprobs = sample.logprobs
+            outputs.append((output_ids, output_str, output_logprobs))
+        return outputs
+
     def generate_greedy(
         self,
         prompts: List[str],
@@ -209,6 +227,20 @@ class VllmRunner:
         return [(output_ids[0], output_str[0])
                 for output_ids, output_str in outputs]
 
+    def generate_greedy_logprobs(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str]]:
+        greedy_logprobs_params = SamplingParams(temperature=0.0,
+                                                max_tokens=max_tokens,
+                                                logprobs=num_logprobs)
+        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)
+
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def generate_beam_search(
         self,
         prompts: List[str],
97  tests/models/test_marlin.py  (new file)
@@ -0,0 +1,97 @@
"""Compare the outputs of a GPTQ model to a Marlin model.

Note: GPTQ and Marlin do not have bitwise correctness.
As a result, in this test, we just confirm that the top selected tokens of the
Marlin/GPTQ models are in the top 3 selections of each other.

Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the test
up to 3 times to see if we pass.

Run `pytest tests/models/test_marlin.py --forked`.
"""

import pytest
import torch
from dataclasses import dataclass
from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (
    capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())


@dataclass
class ModelPair:
    model_marlin: str
    model_gptq: str


model_pairs = [
    ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128",
              model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"),
    ModelPair(model_marlin="robertgshaw2/zephyr-7b-beta-channelwise-marlin",
              model_gptq="robertgshaw2/zephyr-7b-beta-channelwise-gptq"),
    ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin",
              model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq")
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
    marlin_outputs = marlin_model.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)

    # Note: not sure why, but deleting just the model on Ada Lovelace
    # does not free the GPU memory. On Ampere, deleting just the model
    # frees the memory.
    del marlin_model.model.llm_engine.driver_worker
    del marlin_model

    gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
    gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)

    # Note: not sure why, but deleting just the model on Ada Lovelace
    # does not free the GPU memory. On Ampere, deleting just the model
    # frees the memory.
    del gptq_model.model.llm_engine.driver_worker
    del gptq_model

    # loop through the prompts
    for prompt_idx in range(len(example_prompts)):
        gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
            prompt_idx]
        marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
            prompt_idx]

        for idx, (gptq_output_id, marlin_output_id) in enumerate(
                zip(gptq_output_ids, marlin_output_ids)):
            # If the sequences are not an exact match,
            if marlin_output_id != gptq_output_id:
                # each predicted token must be in the top 3 logprobs of the other.
                assert gptq_output_id in marlin_logprobs[idx], (
                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
                )
                assert marlin_output_id in gptq_logprobs[idx], (
                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
                )

                # Break out since sequences will now diverge.
                break
@@ -155,15 +155,21 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm"]
-        rocm_not_supported_quantization = ["awq"]
+        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
+        rocm_not_supported_quantization = ["awq", "marlin"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
         # Parse quantization method from the HF model config, if available.
         hf_quant_config = getattr(self.hf_config, "quantization_config", None)
         if hf_quant_config is not None:
             hf_quant_method = str(hf_quant_config["quant_method"]).lower()
+            # If the GPTQ model is serialized in marlin format, use marlin.
+            if (hf_quant_method == "gptq"
+                    and "is_marlin_format" in hf_quant_config
+                    and hf_quant_config["is_marlin_format"]):
+                hf_quant_method = "marlin"
             if self.quantization is None:
                 self.quantization = hf_quant_method
             elif self.quantization != hf_quant_method:
@@ -183,7 +189,9 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not supported "
                     f"in ROCm.")
-            logger.warning(f"{self.quantization} quantization is not fully "
+            if self.quantization != "marlin":
+                logger.warning(
+                    f"{self.quantization} quantization is not fully "
                     "optimized yet. The speed can be slower than "
                     "non-quantized models.")
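To make the detection above concrete, here is a sketch of the relevant fields of a hypothetical quantization config (quantize_config.json) for a GPTQ checkpoint that has been re-serialized in Marlin format; only the keys shown are consulted by the code in this commit.

# Hypothetical HF quantization config as seen by _verify_quantization():
hf_quant_config = {
    "quant_method": "gptq",     # written by the GPTQ tooling
    "is_marlin_format": True,   # flags Marlin serialization, so quantization resolves to "marlin"
    "group_size": 128,          # later read by MarlinConfig.from_config()
}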
@@ -17,6 +17,14 @@ from vllm.logger import init_logger
 logger = init_logger(__name__)
 
 
+def adjust_marlin_shard(param, shard_size, shard_offset):
+    marlin_tile_size = getattr(param, "marlin_tile_size", None)
+    if marlin_tile_size is None:
+        return shard_size, shard_offset
+
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
+
 class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
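For intuition: the Marlin-packed weight created below stores output_size * tile_size // pack_factor columns (see create_weights in marlin.py later in this commit), so a shard size and offset that have just been divided by pack_factor must additionally be scaled by the tile size. A small worked sketch with hypothetical numbers:

# Hypothetical packed shard: 512 columns starting at packed offset 1024.
shard_size, shard_offset = 512, 1024
marlin_tile_size = 16  # set on the param by MarlinLinearMethod.create_weights()

# adjust_marlin_shard() rescales both values when the attribute is present:
shard_size, shard_offset = (shard_size * marlin_tile_size,
                            shard_offset * marlin_tile_size)
assert (shard_size, shard_offset) == (8192, 16384)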
@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             start_idx = tp_rank * shard_size
@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                       loaded_shard_id: Optional[str] = None):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                 if packed_dim == output_dim:
                     shard_size = shard_size // param.pack_factor
                     shard_offset = shard_offset // param.pack_factor
+
+                    # If marlin, we need to adjust the offset and size to account for the tiling.
+                    shard_size, shard_offset = adjust_marlin_shard(
+                        param, shard_size, shard_offset)
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size)
                 self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
             if packed_dim == output_dim:
                 shard_size = shard_size // param.pack_factor
                 shard_offset = shard_offset // param.pack_factor
+
+                # If marlin, we need to adjust the offset and size to account for the tiling.
+                shard_size, shard_offset = adjust_marlin_shard(
+                    param, shard_size, shard_offset)
+
             param_data = param_data.narrow(output_dim, shard_offset,
                                            shard_size)
             if loaded_shard_id == "q":
@@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 
 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
+    "marlin": MarlinConfig,
 }
 
 
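With the registry entry in place, callers can resolve the Marlin config class by name; the new test file earlier in this commit uses exactly this lookup to skip on GPUs that are too old:

import torch
from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
# MarlinConfig.get_min_capability() returns 80, i.e. Ampere or newer.
marlin_supported = capability >= _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability()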
210  vllm/model_executor/layers/quantization/marlin.py  (new file)
@@ -0,0 +1,210 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group size 128 and -1 (channelwise) is supported for "
                f"Marlin, but got group_size of {self.group_size}")

        # 4 bits packed into a 32 bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by the marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return f"MarlinConfig(group_size={self.group_size})"

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)

    def get_linear_method(self) -> "MarlinLinearMethod":
        return MarlinLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
            )
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
            )
        if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
            )

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4-bit weights packed into int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (used for the internal locking mechanism)
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        return {
            "B": qweight,
            "s": scales,
            "workspace": workspace,
        }

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = weights["B"]
        scales = weights["s"]
        workspace = weights["workspace"]

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
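To make the allocation logic in create_weights concrete, here is a sketch of the shapes it produces for a hypothetical 4096x4096 layer with group_size=128; the numbers follow directly from the constants above (tile_size=16, pack_factor=8, min_n_threads=64, max_parallel=16).

in_features, out_features, group_size = 4096, 4096, 128   # hypothetical layer
tile_size, pack_factor, min_n_threads, max_parallel = 16, 8, 64, 16

qweight_shape = (in_features // tile_size,
                 out_features * tile_size // pack_factor)       # (256, 8192), int32
scales_shape = (in_features // group_size, out_features)        # (32, 4096), float16
workspace_len = (out_features // min_n_threads) * max_parallel  # 1024 int32 zeros, used for locking

print(qweight_shape, scales_shape, workspace_len)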