Integrate Marlin Kernels for Int4 GPTQ inference (#2497)
Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>
parent 90fbf12540
commit c0c2335ce0
@@ -84,6 +84,15 @@ torch::Tensor awq_dequantize(
    int split_k_iters,
    int thx,
    int thy);

torch::Tensor marlin_gemm(
    torch::Tensor& a,
    torch::Tensor& b_q_weight,
    torch::Tensor& b_scales,
    torch::Tensor& workspace,
    int64_t size_m,
    int64_t size_n,
    int64_t size_k);
#endif

void squeezellm_gemm(
@@ -52,11 +52,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Quantization ops
#ifndef USE_ROCM
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif

  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
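For orientation, here is a minimal sketch of driving the newly registered op from Python. It assumes a CUDA build of vLLM that includes this commit and an SM 8.0+ GPU; the shapes follow the layout that MarlinLinearMethod.create_weights allocates later in this diff (tile_size=16, pack_factor=8, min_n_threads=64, max_parallel=16). The weight bits here are random, so the result is numerically meaningless; the sketch only illustrates the calling convention.

import torch
from vllm._C import ops  # extension module built by setup.py (see the setup.py change below)

m, k, n = 8, 4096, 4096            # size_m, size_k, size_n
group_size = 128                   # Marlin supports 128 or -1 (channelwise)

a = torch.randn(m, k, dtype=torch.float16, device="cuda")
# Marlin-packed weights: (k // tile_size, n * tile_size // pack_factor) int32.
b_q_weight = torch.randint(-2**31, 2**31 - 1, (k // 16, n * 16 // 8),
                           dtype=torch.int32, device="cuda")
b_scales = torch.ones(k // group_size, n, dtype=torch.float16, device="cuda")
# Workspace used by the kernel's internal locks; must start zeroed.
workspace = torch.zeros((n // 64) * 16, dtype=torch.int, device="cuda")

out = ops.marlin_gemm(a, b_q_weight, b_scales, workspace, m, n, k)
print(out.shape)  # torch.Size([8, 4096])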
csrc/quantization/marlin/LICENSE  (new file, 209 lines)
@@ -0,0 +1,209 @@
Contains code from https://github.com/IST-DASLab/marlin

[Standard text of the Apache License, Version 2.0, January 2004 (http://www.apache.org/licenses/).]

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/ for the
text of these licenses.
csrc/quantization/marlin/marlin_cuda_kernel.cu  (new file, 1145 lines)
File diff suppressed because it is too large.
@@ -15,6 +15,7 @@ types-setuptools
pytest
pytest-forked
pytest-asyncio
pytest-rerunfailures
httpx
einops # required for MPT
openai
setup.py  (2 changed lines)
@@ -342,6 +342,8 @@ vllm_extension_sources = [

if _is_cuda():
    vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
    vllm_extension_sources.append(
        "csrc/quantization/marlin/marlin_cuda_kernel.cu")
    vllm_extension_sources.append("csrc/custom_all_reduce.cu")

# Add MoE kernels.
@@ -199,6 +199,24 @@ class VllmRunner:
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
    ) -> List[Tuple[List[int], str]]:
        assert sampling_params.logprobs is not None

        req_outputs = self.model.generate(prompts,
                                          sampling_params=sampling_params)
        outputs = []
        for req_output in req_outputs:
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = sample.token_ids
                output_logprobs = sample.logprobs
            outputs.append((output_ids, output_str, output_logprobs))
        return outputs

    def generate_greedy(
        self,
        prompts: List[str],
@@ -209,6 +227,20 @@ class VllmRunner:
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
    ) -> List[Tuple[List[int], str]]:
        greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                max_tokens=max_tokens,
                                                logprobs=num_logprobs)
        outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params)

        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

    def generate_beam_search(
        self,
        prompts: List[str],
tests/models/test_marlin.py  (new file, 97 lines)
@@ -0,0 +1,97 @@
"""Compare the outputs of a GPTQ model to a Marlin model.

Note: GPTQ and Marlin outputs are not bitwise identical.
As a result, in this test, we just confirm that the top selected token of the
Marlin model is among the top 3 selections of the GPTQ model, and vice versa.

Note: Marlin internally uses locks to synchronize the threads. This can
result in very slight nondeterminism for Marlin. As a result, we re-run the
test up to 3 times to see if it passes.

Run `pytest tests/models/test_marlin.py --forked`.
"""

import pytest
import torch
from dataclasses import dataclass
from vllm.model_executor.layers.quantization import _QUANTIZATION_CONFIG_REGISTRY

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
marlin_not_supported = (
    capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())


@dataclass
class ModelPair:
    model_marlin: str
    model_gptq: str


model_pairs = [
    ModelPair(model_marlin="nm-testing/zephyr-beta-7b-marlin-g128",
              model_gptq="nm-testing/zephyr-beta-7b-gptq-g128"),
    ModelPair(model_marlin="robertgshaw2/zephyr-7b-beta-channelwise-marlin",
              model_gptq="robertgshaw2/zephyr-7b-beta-channelwise-gptq"),
    ModelPair(model_marlin="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin",
              model_gptq="robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq")
]


@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype)
    marlin_outputs = marlin_model.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)

    # Note: not sure why, but deleting just the model on Ada Lovelace
    # does not free the GPU memory. On Ampere, deleting just the model
    # frees the memory.
    del marlin_model.model.llm_engine.driver_worker
    del marlin_model

    gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype)
    gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)

    # Note: not sure why, but deleting just the model on Ada Lovelace
    # does not free the GPU memory. On Ampere, deleting just the model
    # frees the memory.
    del gptq_model.model.llm_engine.driver_worker
    del gptq_model

    # Loop through the prompts.
    for prompt_idx in range(len(example_prompts)):
        gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[
            prompt_idx]
        marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[
            prompt_idx]

        for idx, (gptq_output_id, marlin_output_id) in enumerate(
                zip(gptq_output_ids, marlin_output_ids)):
            # If the sequences do not match exactly at this position,
            if marlin_output_id != gptq_output_id:
                # each predicted token must be in the other's top num_logprobs.
                assert gptq_output_id in marlin_logprobs[idx], (
                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
                )
                assert marlin_output_id in gptq_logprobs[idx], (
                    f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\nMarlin:\t{marlin_output_str!r}"
                )

                # Break out since the sequences will now diverge.
                break
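Distilled, the cross-check in the loop above amounts to the following pure-Python predicate (illustrative only; in this version of vLLM, each element of a logprobs list is a dict keyed by token id, which is what makes the `in` checks above work):

from typing import Dict, List


def mutually_in_top_k(ids_a: List[int], ids_b: List[int],
                      logprobs_a: List[Dict[int, float]],
                      logprobs_b: List[Dict[int, float]]) -> bool:
    """True if, at the first position where the two greedy sequences differ,
    each model's chosen token still appears in the other's top-k logprobs."""
    for idx, (tok_a, tok_b) in enumerate(zip(ids_a, ids_b)):
        if tok_a != tok_b:
            # Only the first divergence point is checked; later tokens are
            # conditioned on different prefixes and are expected to differ.
            return tok_a in logprobs_b[idx] and tok_b in logprobs_a[idx]
    return True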
@@ -155,15 +155,21 @@ class ModelConfig:
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
-       supported_quantization = ["awq", "gptq", "squeezellm"]
-       rocm_not_supported_quantization = ["awq"]
+       supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
+       rocm_not_supported_quantization = ["awq", "marlin"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
+           # If the GPTQ model is serialized in marlin format, use marlin.
+           if (hf_quant_method == "gptq"
+                   and "is_marlin_format" in hf_quant_config
+                   and hf_quant_config["is_marlin_format"]):
+               hf_quant_method = "marlin"
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
@@ -183,9 +189,11 @@ class ModelConfig:
                raise ValueError(
                    f"{self.quantization} quantization is currently not supported "
                    f"in ROCm.")
-           logger.warning(f"{self.quantization} quantization is not fully "
-                          "optimized yet. The speed can be slower than "
-                          "non-quantized models.")
+           if self.quantization != "marlin":
+               logger.warning(
+                   f"{self.quantization} quantization is not fully "
+                   "optimized yet. The speed can be slower than "
+                   "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        if self.max_context_len_to_capture is None:
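For context, the new override keys off the checkpoint's Hugging Face quantization config. A GPTQ checkpoint exported in Marlin format would be detected as in this small sketch; only `quant_method` and `is_marlin_format` are read by the code above, and the remaining fields are illustrative GPTQ-style values, not taken from this commit.

# Illustrative quantization_config as it might appear in a checkpoint's config.json.
hf_quant_config = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "is_marlin_format": True,
}

hf_quant_method = str(hf_quant_config["quant_method"]).lower()
# Mirrors the override added above: a GPTQ checkpoint serialized in Marlin
# format is treated as "marlin" so the Marlin linear method gets selected.
if (hf_quant_method == "gptq"
        and "is_marlin_format" in hf_quant_config
        and hf_quant_config["is_marlin_format"]):
    hf_quant_method = "marlin"

print(hf_quant_method)  # marlin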
@@ -17,6 +17,14 @@ from vllm.logger import init_logger

logger = init_logger(__name__)


def adjust_marlin_shard(param, shard_size, shard_offset):
    marlin_tile_size = getattr(param, "marlin_tile_size", None)
    if marlin_tile_size is None:
        return shard_size, shard_offset

    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size


class LinearMethodBase(ABC):
    """Base class for different (maybe quantized) linear methods."""
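A quick numeric illustration of the helper (the layer sizes are hypothetical; pack_factor=8 and marlin_tile_size=16 match the attributes set on the Marlin qweight in marlin.py further down):

from types import SimpleNamespace

# Hypothetical fused projection: the second logical shard is 4096 output
# columns wide and starts at logical column 4096.
shard_size, shard_offset = 4096, 4096

# Packed dimension: eight 4-bit values per int32 shrink it by pack_factor.
pack_factor = 8
shard_size //= pack_factor     # 512
shard_offset //= pack_factor   # 512

# Marlin stores the weight in 16x16 tiles, so its packed output dim is
# output_size * tile_size // pack_factor and the shard must be rescaled.
param = SimpleNamespace(marlin_tile_size=16)
marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is not None:
    shard_size, shard_offset = (shard_size * marlin_tile_size,
                                shard_offset * marlin_tile_size)

print(shard_size, shard_offset)  # 8192 8192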
@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            start_idx = tp_rank * shard_size
@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                      loaded_shard_id: Optional[str] = None):
        param_data = param.data
        output_dim = getattr(param, "output_dim", None)

        if loaded_shard_id is None:
            # Loaded weight is already packed.
            if output_dim is None:
@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor

                    # If marlin, we need to adjust the offset and size to account for the tiling.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # If marlin, we need to adjust the offset and size to account for the tiling.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            if loaded_shard_id == "q":
@@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig

_QUANTIZATION_CONFIG_REGISTRY = {
    "awq": AWQConfig,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "marlin": MarlinConfig,
}
vllm/model_executor/layers/quantization/marlin.py  (new file, 210 lines)
@@ -0,0 +1,210 @@
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group sizes 128 and -1 (channelwise) are supported for "
                f"Marlin, but got group_size of {self.group_size}")

        # 4 bits packed into a 32-bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by the marlin kernels.
        self.tile_size = 16

        # Min out_features dim
        self.min_n_threads = 64

        # Min in_features dim
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large batch performance)
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

    def __repr__(self) -> str:
        return f"MarlinConfig(group_size={self.group_size})"

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # Need to figure it out
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(group_size)

    def get_linear_method(self) -> "MarlinLinearMethod":
        return MarlinLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []


class MarlinLinearMethod(LinearMethodBase):
    """Linear method for Marlin.

    Args:
        quant_config: The Marlin quantization config.
    """

    def __init__(self, quant_config: MarlinConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.

        if params_dtype != torch.float16:
            raise ValueError(
                f"The params dtype must be float16, but got {params_dtype}")

        # Validate output_size_per_partition
        if output_size_per_partition % self.quant_config.min_n_threads != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
            )
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
            )

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_k_threads != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
            )
        if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
            )

        # Check that we have at least 4 tiles horizontally in the shard
        num_tiles_per_perm = self.quant_config.perm_len // (
            self.quant_config.tile_size**2)
        if output_size_per_partition % num_tiles_per_perm != 0:
            raise ValueError(
                "Each permutation group must reside on the same gpu")

        # Quantized 4-bit weights packed into int32.
        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.tile_size,
                output_size_per_partition * self.quant_config.tile_size //
                self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
                "marlin_tile_size": self.quant_config.tile_size,
            },
        )

        # Determine if channelwise or not
        input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size

        scales = Parameter(
            torch.empty(
                input_groups,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": None if input_groups == 1 else 0,
                "output_dim": 1,
            },
        )

        # Allocate workspace (used for the internal locking mechanism).
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_n_threads) * self.quant_config.max_parallel
        workspace = Parameter(torch.zeros(max_workspace_size,
                                          device="cuda",
                                          dtype=torch.int),
                              requires_grad=False)

        return {
            "B": qweight,
            "s": scales,
            "workspace": workspace,
        }

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qweight = weights["B"]
        scales = weights["s"]
        workspace = weights["workspace"]

        x_2d = x.view(-1, x.shape[-1])

        size_m = x_2d.shape[0]
        size_k = x_2d.shape[1]
        size_n = scales.shape[1]

        output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
                                    size_n, size_k)

        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
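To make the allocations above concrete, a small back-of-the-envelope check for a hypothetical 4096x4096 linear layer with group_size=128 (pure arithmetic, no GPU needed; the constants mirror MarlinConfig above):

tile_size, pack_factor = 16, 32 // 4          # 16x16 tiles, eight 4-bit values per int32
min_n_threads, max_parallel = 64, 16

in_features, out_features, group_size = 4096, 4096, 128  # hypothetical layer

qweight_shape = (in_features // tile_size,
                 out_features * tile_size // pack_factor)    # (256, 8192) int32
scales_shape = (in_features // group_size, out_features)     # (32, 4096) float16
workspace_size = (out_features // min_n_threads) * max_parallel  # 1024 int32 zeros

print(qweight_shape, scales_shape, workspace_size)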