[ Misc ] Refactor Marlin Python Utilities (#6082)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
parent 55f692b46e
commit b675069d74
@@ -5,14 +5,16 @@ import torch.utils.benchmark as benchmark
from benchmark_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, marlin_24_quantize, marlin_quantize)
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, marlin_quantize)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)
from vllm.utils import FlexibleArgumentParser
@@ -5,19 +5,21 @@ Run `pytest tests/kernels/marlin/test_marlin_gemm.py`.
import pytest
import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
    marlin_permute_scales)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
    GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
from vllm.model_executor.layers.quantization.utils.marlin_perms import (
    marlin_perm)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
    marlin_quantize, marlin_weights, pack_fp8_to_int32)
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
    marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace, get_weight_perm, marlin_quantize, marlin_weights)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    gptq_pack, quantize_weights, sort_weights)
@@ -42,11 +44,16 @@ MNK_FACTORS = [
DTYPES = [torch.float16, torch.bfloat16]


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


def rand_data(shape, dtype=torch.float16):
    return torch.randn(shape, dtype=dtype, device="cuda")


@pytest.mark.skipif(not is_marlin_supported(),
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
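Note (not part of the diff): compute_max_diff above is the relative error metric the GEMM tests below rely on — the mean absolute difference normalized by the mean absolute value of the reference. A tiny illustration with made-up tensors:

    output = torch.tensor([1.00, 2.02])      # kernel result (illustrative)
    output_ref = torch.tensor([1.00, 2.00])  # reference result (illustrative)
    rel_err = compute_max_diff(output, output_ref)  # 0.01 / 1.5 ≈ 0.0067
    assert rel_err < 0.04                    # the tolerance used by the tests below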
@@ -93,8 +100,8 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Pack to Marlin format
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits,
                                  marlin_perm[num_bits])
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w_1 = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)

    # Run Marlin repack GPU kernel
    marlin_q_w_2 = ops.gptq_marlin_repack(
@@ -109,7 +116,7 @@ def test_marlin_repack(k_chunk, n_chunk, num_bits, group_size, act_order,
    assert torch.allclose(marlin_q_w_1, marlin_q_w_2)


@pytest.mark.skipif(not is_marlin_supported(),
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@@ -174,7 +181,7 @@ def test_marlin_gemm(
    assert max_diff < 0.04


@pytest.mark.skipif(not is_marlin_supported(),
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_24_N_CHUNKS)
@@ -222,7 +229,7 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
    assert max_diff < 0.04


@pytest.mark.skipif(not is_marlin_supported(),
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@@ -268,13 +275,10 @@ def test_fp8_marlin_gemm(
    # expand it to channelwise
    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
    # Permute scales
    marlin_scales = marlin_permute_scales(
        s=scales,
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=size_k,
                                          size_n=size_n,
                                          group_size=-1,
                                          num_bits=8,
    )
                                          group_size=-1)

    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
                                GPTQ_MARLIN_MAX_PARALLEL)
@@ -6,7 +6,6 @@ Run `pytest tests/quantization/test_compressed_tensors.py`.
import pytest
import torch

from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensorsLinearMethod, CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
@@ -57,12 +56,14 @@ def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
        assert qkv_proj.weight_scale.dtype is torch.float32
        assert qkv_proj.input_scale.dtype is torch.float32

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


def test_compressed_tensors_no_enforce_eager(vllm_runner):
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        sampling_params = SamplingParams()
        output = llm.generate("Hello world!", sampling_params=sampling_params)
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


@@ -84,13 +85,16 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.weight.dtype is torch.int8

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


@pytest.mark.parametrize(
    "wNa16_args",
    [("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
     ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
     ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4)])
def test_compressed_tensors_w4a16(vllm_runner, wNa16_args):
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    model, strategy, group, pack_factor = wNa16_args
    with vllm_runner(model) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
@@ -101,12 +105,15 @@ def test_compressed_tensors_w4a16(vllm_runner, wNa16_args):
        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

        assert qkv_proj.scheme.strategy == strategy
        assert qkv_proj.scheme.group_size == group
        assert qkv_proj.scheme.group_size == (-1 if group is None else group)

        assert qkv_proj.weight_packed.dtype is torch.int32
        assert qkv_proj.weight_scale.dtype is torch.float16
        assert qkv_proj.weight_packed.pack_factor == pack_factor

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
@@ -120,8 +127,7 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
        assert qkv_proj.weight_packed.dtype is torch.int32

        sampling_params = SamplingParams()
        output = llm.generate("Hello world!", sampling_params=sampling_params)
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


@@ -142,6 +148,5 @@ def test_compressed_tensors_fp8(vllm_runner):
        assert len(qkv_proj.input_scale.shape) == 0
        assert len(qkv_proj.weight_scale.shape) == 0

        sampling_params = SamplingParams()
        output = llm.generate("Hello world!", sampling_params=sampling_params)
        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
@@ -6,9 +6,10 @@ from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
    marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, replace_tensor, verify_marlin_supported,
    verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs

__all__ = ["CompressedTensorsWNA16"]
@@ -22,29 +23,40 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                 num_bits: int,
                 group_size: Optional[int] = None):
        self.num_bits = num_bits
        self.pack_factor = 32 // self.num_bits
        self.strategy = strategy

        self.group_size: int
        if group_size is None:
            if self.strategy != "channel":
                raise ValueError(
                    "Marlin kernels require group quantization or "
                    "channelwise quantization, but found no group "
                    "size and strategy is not channelwise.")
            self.group_size = -1
        else:
            self.group_size = group_size

        if self.strategy == "group" and self.group_size is None:
            raise ValueError(
                "group_size must be given when using strategy group")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        pass
        # Verify supported on platform.
        verify_marlin_supported(num_bits=self.num_bits,
                                group_size=self.group_size,
                                is_sym=True)

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        pack_factor = 32 // self.num_bits
        output_size_per_partition = sum(output_partition_sizes)

        if self.group_size is not None:
            group_size = self.group_size
        else:
            group_size = input_size
        # If group_size is -1, we are in channelwise case.
        group_size = input_size if self.group_size == -1 else self.group_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        weight_scale_dim = None
        scales_and_zp_size = input_size // group_size
@@ -57,7 +69,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
        weight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // pack_factor,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
@@ -68,7 +80,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": pack_factor,
                "pack_factor": self.pack_factor,
                "weight_loader": weight_loader
            })
        layer.register_parameter("weight_packed", weight)
@@ -103,73 +115,48 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        layer.input_size = input_size
        layer.marlin_state = GPTQMarlinState.REPACK
        layer.is_k_full = True
        layer.group_size = group_size

        max_workspace_size = (
            output_size_per_partition //
            GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    # Checkpoints are serialized in compressed-tensors format, which is
    # different from marlin format. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.weight_packed.device

        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)
        layer.workspace = workspace
        # Allocate marlin workspace.
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Act-order not supported in compressed-tensors yet, so set to empty.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # Repack weights from compressed-tensors format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.weight_packed.t().contiguous(),
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.num_bits)
        replace_tensor(layer, "weight_packed", marlin_qweight)

        # Permute scales from compressed-tensors format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.weight_scale.squeeze().t().contiguous(),
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=layer.group_size)
        replace_tensor(layer, "weight_scale", marlin_scales)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        reshaped_x = x.reshape(-1, x.shape[-1])

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.weight_packed.device

            # Reset g_idx related tensors
            layer.g_idx = Parameter(torch.empty(0,
                                                dtype=torch.int,
                                                device=cur_device),
                                    requires_grad=False)
            layer.g_idx_sort_indices = Parameter(torch.empty(
                0, dtype=torch.int, device=cur_device),
                                                 requires_grad=False)

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
                part_size_k, part_size_n, self.num_bits)

            replace_tensor("weight_packed", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n

            marlin_scales = marlin_permute_scales(
                layer.weight_scale.squeeze().t().contiguous(), scales_size_k,
                scales_size_n, layer.group_size, self.num_bits)
            replace_tensor("weight_scale", marlin_scales)

        output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed,
                                      layer.weight_scale, layer.g_idx,
                                      layer.g_idx_sort_indices,
                                      layer.workspace, self.num_bits, size_m,
                                      part_size_n, part_size_k,
                                      layer.is_k_full)
        return output.reshape(out_shape)
        return apply_marlin_linear(
            input=x,
            weight=layer.weight_packed,
            weight_scale=layer.weight_scale,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.num_bits,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            is_k_full=True)
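Aside (not from this commit): replace_tensor, used in the repack path above, exists because the checkpoint-shaped tensors are already registered as vLLM parameters; resize_() plus copy_() swaps in the Marlin-format data without re-registering the parameter. A minimal sketch with an illustrative module:

    import torch
    from vllm.model_executor.layers.quantization.utils.marlin_utils import replace_tensor

    layer = torch.nn.Module()
    layer.weight_packed = torch.nn.Parameter(torch.zeros(4, 4, dtype=torch.int32),
                                             requires_grad=False)
    repacked = torch.ones(2, 8, dtype=torch.int32)   # pretend Marlin-format output
    replace_tensor(layer, "weight_packed", repacked)
    assert layer.weight_packed.shape == (2, 8)       # same Parameter, new contents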
@@ -11,7 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
@@ -1,5 +1,3 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
@@ -12,46 +10,14 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_K,
    GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_SUPPORTED_GROUP_SIZES,
    GPTQ_MARLIN_SUPPORTED_NUM_BITS, GPTQ_MARLIN_SUPPORTED_SYM,
    GPTQ_MARLIN_TILE)
    check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
    verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform

logger = init_logger(__name__)


# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def get_pack_factor(num_bits: int):
    assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            ), f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int, num_bits: int):
    scale_perm, scale_perm_single = get_scale_perms(num_bits)
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""
@@ -63,33 +29,16 @@ class GPTQMarlinConfig(QuantizationConfig):
            desc_act = False

        self.weight_bits = weight_bits
        self.pack_factor = 32 // self.weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.is_sym = is_sym
        self.lm_head_quantized = lm_head_quantized

        # Verify
        if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
            raise ValueError(
                f"Marlin does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
                "are supported.")
        if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"Marlin does not support group_size = {self.group_size}. "
                f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
                "are supported.")
        if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
            raise ValueError(
                f"Marlin does not support is_sym = {self.is_sym}. "
                f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")

        # Init
        self.pack_factor = get_pack_factor(weight_bits)
        self.tile_size = GPTQ_MARLIN_TILE
        self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
        self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
        self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
        # Verify supported on platform.
        verify_marlin_supported(num_bits=self.weight_bits,
                                group_size=self.group_size,
                                is_sym=self.is_sym)

    def __repr__(self) -> str:
        return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
@@ -168,21 +117,10 @@ class GPTQMarlinConfig(QuantizationConfig):
                or desc_act is None):
            return False

        # If the capability of the device is too low, cannot convert.
        major, minor = current_platform.get_device_capability()
        device_capability = major * 10 + minor
        if device_capability < cls.get_min_capability():
            return False

        # Otherwise, can convert if model satisfies marlin constraints.
        return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
                and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
                and sym in GPTQ_MARLIN_SUPPORTED_SYM)


class GPTQMarlinState(Enum):
    REPACK = enum.auto()
    READY = enum.auto()
        return check_marlin_supported(num_bits=num_bits,
                                      group_size=group_size,
                                      is_sym=sym,
                                      min_capability=cls.get_min_capability())


class GPTQMarlinLinearMethod(LinearMethodBase):
@@ -206,6 +144,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        **extra_weight_attrs,
    ) -> None:
        del output_size
        output_size_per_partition = sum(output_partition_sizes)

        # Normalize group_size
        if self.quant_config.group_size != -1:
@@ -213,31 +152,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        else:
            group_size = input_size

        # Validate dtype
        if params_dtype not in [torch.float16, torch.bfloat16]:
            raise ValueError(f"The params dtype must be float16 "
                             f"or bfloat16, but got {params_dtype}")

        # Validate output_size_per_partition
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.min_thread_n != 0:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{output_size_per_partition} is not divisible by "
                f" min_thread_n = {self.quant_config.min_thread_n}.")

        # Validate input_size_per_partition
        if input_size_per_partition % self.quant_config.min_thread_k != 0:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible "
                f"by min_thread_k = {self.quant_config.min_thread_k}.")

        if (group_size < input_size
                and input_size_per_partition % group_size != 0):
            raise ValueError(
                f"Weight input_size_per_partition = {input_size_per_partition}"
                f" is not divisible by group_size = {group_size}.")
        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        # Detect sharding of scales/zp
@@ -303,11 +222,6 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
            },
        )

        g_idx_sort_indices = torch.empty(
            g_idx.shape,
            dtype=torch.int32,
        )

        # Scales
        scales = Parameter(
            torch.empty(
@@ -347,25 +261,50 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
            },
        )

        # Allocate marlin workspace
        max_workspace_size = (
            output_size_per_partition //
            self.quant_config.min_thread_n) * self.quant_config.max_parallel
        workspace = torch.zeros(max_workspace_size,
                                dtype=torch.int,
                                requires_grad=False)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("scales", scales)
        layer.register_parameter("qzeros", qzeros)
        layer.g_idx_sort_indices = g_idx_sort_indices
        layer.workspace = workspace
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.input_size = input_size
        layer.is_k_full = is_k_full
        layer.marlin_state = GPTQMarlinState.REPACK

    # Checkpoints are serialized in AutoGPTQ format, which is different from the
    # marlin format. This function is called after the weights are loaded.
    # Here, we handle the repacking, including the activation reordering case.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = layer.qweight.device
        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(
            layer.output_size_per_partition, device)

        # Handle sorting for activation reordering if needed.
        if self.quant_config.desc_act:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
            replace_tensor(layer, "g_idx", g_idx)
        else:
            layer.g_idx = marlin_make_empty_g_idx(device)
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        # Repack weights from autogptq format to marlin format.
        marlin_qweight = ops.gptq_marlin_repack(
            layer.qweight,
            perm=layer.g_idx_sort_indices,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.weight_bits)
        replace_tensor(layer, "qweight", marlin_qweight)

        # Permute scales from autogptq format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=(layer.input_size if self.quant_config.desc_act else
                    layer.input_size_per_partition),
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size)
        replace_tensor(layer, "scales", marlin_scales)

    def apply(
        self,
@@ -374,87 +313,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (layer.output_size_per_partition, )

        size_m = reshaped_x.shape[0]
        part_size_n = layer.output_size_per_partition
        part_size_k = layer.input_size_per_partition
        full_size_k = layer.input_size

        out_shape = x.shape[:-1] + (part_size_n, )

        if layer.marlin_state == GPTQMarlinState.REPACK:
            layer.marlin_state = GPTQMarlinState.READY

            # Newly generated tensors need to replace existing tensors that are
            # already registered as parameters by vLLM (and won't be freed)
            def replace_tensor(name, new_t):
                # It is important to use resize_() here since it ensures
                # the same buffer is reused
                getattr(layer, name).resize_(new_t.shape)
                getattr(layer, name).copy_(new_t)
                del new_t

            cur_device = layer.qweight.device

            # Process act_order
            if self.quant_config.desc_act:
                # Get sorting based on g_idx
                g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)

                sorted_g_idx = layer.g_idx[g_idx_sort_indices]

                replace_tensor("g_idx", sorted_g_idx)
                replace_tensor("g_idx_sort_indices", g_idx_sort_indices)

            else:
                # Reset g_idx related tensors
                layer.g_idx = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )
                layer.g_idx_sort_indices = Parameter(
                    torch.empty(0, dtype=torch.int, device=cur_device),
                    requires_grad=False,
                )

            # Repack weights
            marlin_qweight = ops.gptq_marlin_repack(
                layer.qweight,
                layer.g_idx_sort_indices,
                part_size_k,
                part_size_n,
                self.quant_config.weight_bits,
            )
            replace_tensor("qweight", marlin_qweight)

            # Permute scales
            scales_size_k = part_size_k
            scales_size_n = part_size_n
            if self.quant_config.desc_act:
                scales_size_k = full_size_k

            marlin_scales = marlin_permute_scales(
                layer.scales,
                scales_size_k,
                scales_size_n,
                self.quant_config.group_size,
                self.quant_config.weight_bits,
            )
            replace_tensor("scales", marlin_scales)

        output = ops.gptq_marlin_gemm(
            reshaped_x,
        output = ops.gptq_marlin_gemm(reshaped_x,
                                      layer.qweight,
                                      layer.scales,
                                      layer.g_idx,
                                      layer.g_idx_sort_indices,
                                      layer.workspace,
                                      self.quant_config.weight_bits,
                                      size_m,
                                      part_size_n,
                                      part_size_k,
                                      layer.is_k_full,
                                      )
                                      g_idx=layer.g_idx,
                                      perm=layer.g_idx_sort_indices,
                                      workspace=layer.workspace,
                                      num_bits=self.quant_config.weight_bits,
                                      size_m=reshaped_x.shape[0],
                                      size_n=layer.output_size_per_partition,
                                      size_k=layer.input_size_per_partition,
                                      is_k_full=layer.is_k_full)

        if bias is not None:
            output.add_(bias)  # In-place add
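For clarity (not part of the diff): marlin_sort_g_idx, used in the desc_act branch above and defined in marlin_utils.py further down, returns the group indices sorted ascending together with the permutation that sorts them, e.g.:

    import torch
    from vllm.model_executor.layers.quantization.utils.marlin_utils import marlin_sort_g_idx

    g_idx = torch.tensor([2, 0, 1], dtype=torch.int)
    sorted_g_idx, perm = marlin_sort_g_idx(g_idx)
    # sorted_g_idx -> tensor([0, 1, 2], dtype=torch.int32)
    # perm         -> tensor([1, 2, 0], dtype=torch.int32)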
@@ -1,60 +0,0 @@
"""This file is used for /tests and /benchmarks"""
from typing import Dict, List

import numpy
import torch


# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        col_o = col // 2
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
                             4 * block)
        for j in range(4):
            perm_list.extend([p + 1 * j for p in perm1])
    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
    scale_perm_single: List[int] = []
    for i in range(8):
        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
    return perm, scale_perm, scale_perm_single


marlin_24_perm: Dict[int, torch.Tensor] = {}
marlin_24_scale_perm: Dict[int, List[int]] = {}
marlin_24_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]:
    perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
    marlin_24_perm[num_bits] = perm_24
    marlin_24_scale_perm[num_bits] = scale_perm_24
    marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
@@ -1,60 +0,0 @@
"""This file is used for /tests and /benchmarks"""
from typing import Dict, List

import numpy
import torch


# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return perm, scale_perm, scale_perm_single


marlin_perm: Dict[int, torch.Tensor] = {}
marlin_scale_perm: Dict[int, List[int]] = {}
marlin_scale_perm_single: Dict[int, List[int]] = {}
for num_bits in [4, 8]:
    perm, scale_perm, scale_perm_single = get_perms(num_bits)
    marlin_perm[num_bits] = perm
    marlin_scale_perm[num_bits] = scale_perm
    marlin_scale_perm_single[num_bits] = scale_perm_single
@@ -1,21 +1,9 @@
"""This file is used for /tests and /benchmarks"""
import random
from typing import Optional
from typing import List, Optional, Tuple

import numpy
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.format_24 import (
    mask_creator, sparse_semi_structured_from_dense_cutlass)
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
    marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.marlin_perms import (
    marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    get_pack_factor, quantize_weights, sort_weights)
from vllm.platforms import current_platform
from vllm.utils import print_warning_once

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
@@ -25,135 +13,110 @@ GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1]


def is_marlin_supported():
    capability = current_platform.get_device_capability()
    return capability[0] >= 8
def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                           min_capability: int) -> bool:

    # If the capability of the device is too low, cannot convert.
    major, minor = current_platform.get_device_capability()
    device_capability = major * 10 + minor
    if device_capability < min_capability:
        return False

    return (device_capability >= min_capability
            and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
            and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
            and is_sym in GPTQ_MARLIN_SUPPORTED_SYM)


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization
def verify_marlin_supported(num_bits: int, group_size: Optional[int],
                            is_sym: bool) -> None:

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)
    if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
        raise ValueError(
            f"Marlin does not support weight_bits = {num_bits}. "
            f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
            "are supported.")
    if (group_size is None
            or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES):
        raise ValueError(
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported.")
    if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
        raise ValueError(
            f"Marlin does not support is_sym = is_sym. "
            f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")


def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")
def verify_marlin_supports_shape(output_size_per_partition: int,
                                 input_size_per_partition: int,
                                 input_size: int, group_size: int) -> None:

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    # Validate output_size_per_partition
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(f"Weight output_size_per_partition = "
                         f"{output_size_per_partition} is not divisible by "
                         f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    device = layer.weight.device
    # Validate input_size_per_partition
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(f"Weight input_size_per_partition = "
                         f"{input_size_per_partition} is not divisible "
                         f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
                         "Consider reducing tensor_parallel_size or running "
                         "with --quantization gptq.")

    # WEIGHTS
    # Repack weights to gptq format (packed int32 elements)
    packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
    if (group_size < input_size
            and input_size_per_partition % group_size != 0):
        raise ValueError(
            f"Weight input_size_per_partition = {input_size_per_partition}"
            f" is not divisible by group_size = {group_size}."
            "Consider reducing tensor_parallel_size or running "
            "with --quantization gptq.")

    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(
        b_q_weight=packed_gptq_qweight,
        perm=torch.empty(0, dtype=torch.int, device=device),
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    scales = layer.weight_scale.repeat(1, part_size_n).to(
        layer.orig_dtype).to(device)
    # Permute scales
    num_bits = 8
    marlin_scales = marlin_permute_scales(
        s=scales,
        size_k=part_size_k,
        size_n=part_size_n,
        group_size=-1,
        scale_perm=marlin_scale_perm[num_bits],
        scale_perm_single=marlin_scale_perm_single[num_bits])
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)

    # Allocate marlin workspace
    max_workspace_size = (part_size_n //
def marlin_make_workspace(output_size_per_partition: int,
                          device: torch.device) -> torch.Tensor:
    max_workspace_size = (output_size_per_partition //
                          GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    workspace = torch.zeros(max_workspace_size,

    return torch.zeros(max_workspace_size,
                       dtype=torch.int,
                       device=device,
                       requires_grad=False)

    layer.workspace = workspace

def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
                              requires_grad=False)


def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w
def marlin_sort_g_idx(
        g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
    return g_idx[g_idx_sort_indices], g_idx_sort_indices


def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed
def get_scale_perms():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: List[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
    return scale_perm, scale_perm_single


def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
                          scale_perm_single):
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                          group_size: int) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
@@ -163,180 +126,44 @@ def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
    return s


def marlin_quantize(
    w: torch.Tensor,
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
                   new_t: torch.Tensor) -> None:
    # It is important to use resize_() here since it ensures
    # the same buffer is reused
    getattr(layer, name).resize_(new_t.shape)
    getattr(layer, name).copy_(new_t)
    del new_t


def apply_marlin_linear(input: torch.Tensor,
                        weight: torch.Tensor,
                        weight_scale: torch.Tensor,
                        g_idx: torch.Tensor,
                        g_idx_sort_indices: torch.Tensor,
                        workspace: torch.Tensor,
                        num_bits: int,
    group_size: int,
    act_order: bool,
):
    size_k, size_n = w.shape
                        output_size_per_partition: int,
                        input_size_per_partition: int,
                        is_k_full: bool,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
                                marlin_perm[num_bits])
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                     marlin_scale_perm[num_bits],
                                     marlin_scale_perm_single[num_bits])

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def inject_24(w, size_k, size_n):
    assert w.shape == (size_k, size_n)

    mask = mask_creator(w.t()).t().cuda().bool()

    return (mask * w).contiguous(), mask.contiguous()


def check_24(w, num_rows_to_sample=50, _verbose=False):
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")


def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    assert q_24.shape == (size_k, size_n)

    # Remove zp to normalize over 0
    max_q_val = (1 << num_bits) - 1
    zp = (max_q_val + 1) // 2
    q_24_no_zp = q_24 - zp

    # Compress
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
        q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore zp
    q_24_comp = q_24_no_zp_comp + zp

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta


def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  num_bits,
                                  group_size,
                                  act_order=False)
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full)

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2
    if bias is not None:
        output.add_(bias)  # In-place add

    # Reformat to marlin
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, marlin_24_perm[num_bits])
    marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
                                        marlin_24_scale_perm[num_bits],
                                        marlin_24_scale_perm_single[num_bits])

    # Create result
    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list


def compute_max_diff(output, output_ref):
    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref))


class MarlinWorkspace:

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (out_features % min_thread_n == 0), (
            "out_features = {} is undivisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        max_workspace_size = ((out_features // min_thread_n) * max_parallel)

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device="cuda")


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (byte_tensor[:, 0].to(torch.int32) |
              (byte_tensor[:, 1].to(torch.int32) << 8) |
              (byte_tensor[:, 2].to(torch.int32) << 16) |
              (byte_tensor[:, 3].to(torch.int32) << 24))

    return packed.view(fp8_tensor.shape[0] // 4,
                       *fp8_tensor.shape[1:]).contiguous()
    return output.reshape(out_shape)
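As a usage sketch (not code from this commit), the shared helpers defined above compose like this — check_marlin_supported gates conversion on kernel constraints plus device capability, and marlin_make_workspace sizes the scratch buffer from the output partition:

    import torch
    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
        check_marlin_supported, marlin_make_workspace)

    ok = check_marlin_supported(num_bits=4,
                                group_size=128,
                                is_sym=True,
                                min_capability=80)  # True on capability >= 8.0 GPUs

    ws = marlin_make_workspace(output_size_per_partition=4096,
                               device=torch.device("cuda"))
    # size = (4096 // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
    #      = (4096 // 64) * 16 = 1024 int32 zeros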
@@ -0,0 +1,109 @@
from typing import Optional

import torch

import vllm._custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import print_warning_once

from .marlin_utils import marlin_make_workspace, marlin_permute_scales


def is_fp8_marlin_supported():
    capability = current_platform.get_device_capability()
    return capability[0] >= 8


def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
) -> torch.Tensor:
    # For GPUs that lack FP8 hardware support, we can leverage the
    # Marlin kernel for fast weight-only FP8 quantization

    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (size_n, )

    output = ops.fp8_marlin_gemm(
        a=reshaped_x,
        b_q_weight=weight,
        b_scales=weight_scale,
        workspace=workspace,
        num_bits=8,
        size_m=reshaped_x.shape[0],
        size_n=size_n,
        size_k=size_k,
    )

    if bias is not None:
        output.add_(bias)  # In-place add

    return output.reshape(out_shape)


def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
    print_warning_once(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads.")

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(part_size_n, device)

    # WEIGHT
    # Repack weights to marlin format
    marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32(
        layer.weight),
                                            perm=torch.empty(0,
                                                             dtype=torch.int,
                                                             device=device),
                                            size_k=part_size_k,
                                            size_n=part_size_n,
                                            num_bits=8)
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Currently Marlin doesn't support per-tensor scales, so we
    # expand it to channelwise
    scales = layer.weight_scale.repeat(1, part_size_n).to(
        layer.orig_dtype).to(device)
    # Permute scales
    marlin_scales = marlin_permute_scales(s=scales,
                                          size_k=part_size_k,
                                          size_n=part_size_n,
                                          group_size=-1)
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)


def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
    """
    Repack FP8 weights to gptq format (packed int32 elements)
    """
    assert fp8_tensor.dtype == torch.float8_e4m3fn
    assert fp8_tensor.shape[0] % 4 == 0

    # Reshape to prepare for packing
    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

    # Convert fp8 to uint8 (byte) representation
    byte_tensor = reshaped.view(torch.uint8)

    # Pack 4 uint8 values into one int32
    packed = (byte_tensor[:, 0].to(torch.int32) |
              (byte_tensor[:, 1].to(torch.int32) << 8) |
              (byte_tensor[:, 2].to(torch.int32) << 16) |
              (byte_tensor[:, 3].to(torch.int32) << 24))

    return packed.view(fp8_tensor.shape[0] // 4,
                       *fp8_tensor.shape[1:]).contiguous()
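A small worked example (not part of the diff) of pack_fp8_to_int32 above: groups of four float8_e4m3fn rows are reinterpreted as bytes and packed into one int32 row, shrinking the leading dimension by 4x:

    import torch
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        pack_fp8_to_int32)

    w = torch.full((8, 2), 1.0, dtype=torch.float16, device="cuda")
    w_fp8 = w.to(torch.float8_e4m3fn)          # fp8 1.0 encodes as byte 0x38
    packed = pack_fp8_to_int32(w_fp8)
    assert packed.dtype == torch.int32
    assert packed.shape == (2, 2)              # 8 fp8 rows -> 8 // 4 = 2 int32 rows
    # every element is 0x38383838 (four copies of the 1.0 byte)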
@ -0,0 +1,120 @@
"""Utility functions used for tests and benchmarks"""

from typing import List

import numpy
import torch

from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales
from .quant_utils import get_pack_factor, quantize_weights, sort_weights


class MarlinWorkspace:

    def __init__(self, out_features, min_thread_n, max_parallel):
        assert (out_features % min_thread_n == 0), (
            "out_features = {} is not divisible by min_thread_n = {}".format(
                out_features, min_thread_n))

        max_workspace_size = ((out_features // min_thread_n) * max_parallel)

        self.scratch = torch.zeros(max_workspace_size,
                                   dtype=torch.int,
                                   device="cuda")
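
For example, tests and benchmarks can build a workspace roughly like this (illustrative: 4096 is an arbitrary output size that is divisible by the minimum thread-N, and a CUDA device is required):

from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    MarlinWorkspace)

# Allocates (out_features // min_thread_n) * max_parallel int32 scratch slots.
workspace = MarlinWorkspace(4096, GPTQ_MARLIN_MIN_THREAD_N,
                            GPTQ_MARLIN_MAX_PARALLEL)
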
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE):
    assert q_w.shape == (size_k, size_n)
    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"

    # Permute weights to 16x64 marlin tiles
    q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
    q_w = q_w.permute((0, 2, 1, 3))
    q_w = q_w.reshape((size_k // tile, size_n * tile))

    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)

    return q_w


def marlin_weights(q_w, size_k, size_n, num_bits, perm):
    # Permute
    q_w = marlin_permute_weights(q_w, size_k, size_n, perm)

    # Pack
    pack_factor = get_pack_factor(num_bits)
    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
                           dtype=numpy.uint32)
    for i in range(pack_factor):
        q_packed |= q_w[:, i::pack_factor] << num_bits * i

    q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)

    return q_packed
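
For intuition about the packing loop above: get_pack_factor(num_bits) is assumed to be 32 // num_bits, so with 4-bit weights eight values share one 32-bit word. The same bit arithmetic on plain Python ints:

num_bits = 4
pack_factor = 32 // num_bits            # 8 four-bit values per 32-bit word
vals = list(range(pack_factor))         # [0, 1, ..., 7]
word = 0
for i, v in enumerate(vals):
    word |= v << (num_bits * i)         # value i occupies bits [4*i, 4*i + 4)
assert word == 0x76543210
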
def get_weight_perm(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col + 8 * block)
        for j in range(4):
            perm_list.extend([p + 256 * j for p in perm1])

    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_quantize(w: torch.Tensor, num_bits: int, group_size: int,
                    act_order: bool):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Quantize (and apply act_order if provided)
    w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
                                                       act_order)

    # For act_order, sort the "weights" and "g_idx" so that group ids are
    # increasing
    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
    if act_order:
        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)

    # Reformat to marlin
    weight_perm = get_weight_perm(num_bits)
    marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm)
    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
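
An illustrative end-to-end call, roughly how the GEMM tests use it (the shape, bit width, and group size are arbitrary but chosen to satisfy the divisibility asserts; a CUDA device is assumed):

import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
    marlin_quantize)

w = torch.randn((1024, 512), dtype=torch.float16, device="cuda")
(w_ref, marlin_q_w, marlin_s, g_idx, sort_indices,
 rand_perm) = marlin_quantize(w, num_bits=4, group_size=128, act_order=False)
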
@ -1,9 +1,14 @@
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
"""Utility functions used for tests and benchmarks"""

import random
from typing import List

import numpy
import torch

from .marlin_utils_test import marlin_weights
from .quant_utils import quantize_weights


# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
@ -306,3 +311,155 @@ def mask_creator(tensor):
    mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)

    return mask


def inject_24(w, size_k, size_n):
    assert w.shape == (size_k, size_n)

    mask = mask_creator(w.t()).t().cuda().bool()

    return (mask * w).contiguous(), mask.contiguous()


def check_24(w, num_rows_to_sample=50, _verbose=False):
    BLOCK_SIZE = 4
    MAX_NON_ZEROS = 2

    w = w.t().contiguous()

    print("check_24: w.shape = {}".format(w.shape))

    num_rows, num_cols = w.shape
    sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
    if _verbose:
        print(f"Sampled row idxs = {sampled_row_idxs}")

    total_segments = 0
    non_24_segments = 0
    for i in sampled_row_idxs:
        for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
            total_segments += 1
            block = w[i, j:j + BLOCK_SIZE]
            num_nonzero = torch.count_nonzero(block)
            if num_nonzero > MAX_NON_ZEROS:
                print("i = {} j = {} block = {}".format(i, j, block))
                non_24_segments += 1

    print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
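
Illustrative only (a CUDA device is assumed, since inject_24 builds its mask on the GPU): prune a small random matrix to 2:4 sparsity and spot-check the pattern.

import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    check_24, inject_24)

w = torch.randn((128, 256), dtype=torch.float16, device="cuda")
w_24, mask = inject_24(w, size_k=128, size_n=256)
check_24(w_24)  # prints how many sampled 4-wide segments violate 2:4
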
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
    assert q_24.shape == (size_k, size_n)

    # Remove zp to normalize over 0
    max_q_val = (1 << num_bits) - 1
    zp = (max_q_val + 1) // 2
    q_24_no_zp = q_24 - zp

    # Compress
    q_24_no_zp = q_24_no_zp.t().contiguous()
    q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
        q_24_no_zp)
    q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()

    # Restore zp
    q_24_comp = q_24_no_zp_comp + zp

    # Resize meta to its actual shape (without moving any data)
    meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)

    return q_24_comp, meta


def get_scale_perms_24():
    scale_perm: List[int] = []
    for i in range(8):
        scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
    scale_perm_single: List[int] = []
    for i in range(8):
        scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
    return scale_perm, scale_perm_single


def get_weight_perm_24(num_bits: int):
    perm_list: List[int] = []
    for i in range(32):
        perm1: List[int] = []
        col = i // 4
        col_o = col // 2
        for block in [0, 1]:
            for row in [
                    2 * (i % 4),
                    2 * (i % 4) + 1,
                    2 * (i % 4 + 4),
                    2 * (i % 4 + 4) + 1,
            ]:
                perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
                             4 * block)
        for j in range(4):
            perm_list.extend([p + 1 * j for p in perm1])
    perm = numpy.array(perm_list)

    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))

    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm


def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int,
                             group_size: int) -> torch.Tensor:

    scale_perm, scale_perm_single = get_scale_perms_24()
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()

    return s


def marlin_24_quantize(
    w: torch.Tensor,
    num_bits: int,
    group_size: int,
):
    size_k, size_n = w.shape

    # Normalize group_size
    if group_size == -1:
        group_size = size_k
    assert group_size <= size_k

    # Inject 2:4 sparsity
    w_24, mask_24 = inject_24(w, size_k, size_n)

    # Quantize
    w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
                                                             num_bits,
                                                             group_size,
                                                             act_order=False)

    # Compress quantized weight
    q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
                                                     num_bits)
    size_k_comp = size_k // 2

    # Reformat to marlin
    weight_perm = get_weight_perm_24(num_bits)
    marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
                                        num_bits, weight_perm)
    marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)

    # Create result
    res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
    for i in range(len(res_list)):
        res_list[i] = res_list[i].to(w.device)

    return res_list
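
And a hedged usage sketch for the top-level entry point (arbitrary shape on a CUDA device; group_size=128 is assumed to be a supported 2:4 group size):

import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
    marlin_24_quantize)

w = torch.randn((1024, 512), dtype=torch.float16, device="cuda")
w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s = marlin_24_quantize(
    w, num_bits=4, group_size=128)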