[ Misc ] More Cleanup of Marlin (#6359)

Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
Robert Shaw 2024-07-13 06:21:37 -04:00 committed by GitHub
parent 9da4aad44b
commit babf52dade
3 changed files with 44 additions and 48 deletions

View File

@@ -3,7 +3,7 @@
# We use this for fp8, which HF does not support.
#
# Make sure you have lm-eval-harness installed:
-# pip install lm-eval==0.4.2
+# pip install lm-eval==0.4.3
usage() {
echo``

View File

@@ -10,8 +10,9 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
-    marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
+    apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
+    marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -145,6 +146,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
+        is_row_parallel = input_size != input_size_per_partition
# Normalize group_size
if self.quant_config.group_size != -1:
@@ -158,32 +160,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
input_size=input_size,
group_size=group_size)
-        # Detect sharding of scales/zp
-        # By default, no sharding over "input dim"
-        scales_and_zp_size = input_size // group_size
-        scales_and_zp_input_dim = None
-        if self.quant_config.desc_act:
-            # Act-order case
-            assert self.quant_config.group_size != -1
-            is_k_full = input_size_per_partition == input_size
+        # Determine sharding
+        if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
+                                             self.quant_config.group_size,
+                                             is_row_parallel):
+            # By setting scale_dim == None, weight_loader will
+            # repeat the scales on each GPU in TP>1 case.
+            scales_and_zp_input_dim = None
+            scales_and_zp_size = input_size // group_size
         else:
-            # No act-order case
-            # K is always full due to full alignment with
-            # group-size and shard of scales/zp
-            is_k_full = True
-            # If this is a row-parallel case, then shard scales/zp
-            if (input_size != input_size_per_partition
-                    and self.quant_config.group_size != -1):
-                scales_and_zp_size = input_size_per_partition // group_size
-                scales_and_zp_input_dim = 0
-        # Init buffers
+            # By setting scale_dim == 0, weight_loader will
+            # shard the scales in TP>1 case.
+            scales_and_zp_input_dim = 0
+            scales_and_zp_size = input_size_per_partition // group_size
# Quantized weights
qweight = Parameter(
@@ -268,13 +257,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
-        layer.is_k_full = is_k_full
+        layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
+                                           is_row_parallel)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
@@ -312,22 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
-        reshaped_x = x.reshape(-1, x.shape[-1])
-        out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
-        output = ops.gptq_marlin_gemm(reshaped_x,
-                                      layer.qweight,
-                                      layer.scales,
-                                      g_idx=layer.g_idx,
-                                      perm=layer.g_idx_sort_indices,
-                                      workspace=layer.workspace,
-                                      num_bits=self.quant_config.weight_bits,
-                                      size_m=reshaped_x.shape[0],
-                                      size_n=layer.output_size_per_partition,
-                                      size_k=layer.input_size_per_partition,
-                                      is_k_full=layer.is_k_full)
-        if bias is not None:
-            output.add_(bias)  # In-place add
-        return output.reshape(out_shape)
+        return apply_marlin_linear(
+            input=x,
+            weight=layer.qweight,
+            weight_scale=layer.scales,
+            g_idx=layer.g_idx,
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=layer.workspace,
+            num_bits=self.quant_config.weight_bits,
+            output_size_per_partition=layer.output_size_per_partition,
+            input_size_per_partition=layer.input_size_per_partition,
+            is_k_full=layer.is_k_full,
+            bias=bias)
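
Note: the hunk above replaces the inline reshape/gemm/bias sequence with the shared apply_marlin_linear helper imported from marlin_utils. As a rough sketch only, reconstructed from the removed lines above rather than from the actual marlin_utils source (the helper's real body, and the _custom_ops import path, are assumptions here), it presumably does something like:

from typing import Optional

import torch

from vllm import _custom_ops as ops  # assumed import path for the Marlin kernel


def apply_marlin_linear(input: torch.Tensor,
                        weight: torch.Tensor,
                        weight_scale: torch.Tensor,
                        g_idx: torch.Tensor,
                        g_idx_sort_indices: torch.Tensor,
                        workspace: torch.Tensor,
                        num_bits: int,
                        output_size_per_partition: int,
                        input_size_per_partition: int,
                        is_k_full: bool,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Flatten all leading dims so the kernel sees a 2D (M, K) activation,
    # exactly as the removed inline code did.
    reshaped_x = input.reshape(-1, input.shape[-1])
    out_shape = input.shape[:-1] + (output_size_per_partition, )

    output = ops.gptq_marlin_gemm(reshaped_x,
                                  weight,
                                  weight_scale,
                                  g_idx=g_idx,
                                  perm=g_idx_sort_indices,
                                  workspace=workspace,
                                  num_bits=num_bits,
                                  size_m=reshaped_x.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full)

    if bias is not None:
        output.add_(bias)  # In-place add, as in the removed code

    return output.reshape(out_shape)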

View File

@@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
requires_grad=False)
+def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
+    return (not act_order) or (act_order and not is_row_parallel)
+
+
+def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
+                                      is_row_parallel: bool) -> bool:
+    # Need to repeat scales on every rank if act_ordering or
+    # channelwise and RowParallelLinear
+    is_channelwise = group_size == -1
+    return act_order or (is_channelwise and is_row_parallel)
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
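
For a quick sanity check of the two helpers added above, the snippet below copies their bodies verbatim and prints the decision for a few tensor-parallel configurations; the case list is illustrative only, not taken from any vLLM test.

# Helper bodies copied verbatim from the hunk above.
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
    return (not act_order) or (act_order and not is_row_parallel)


def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
                                      is_row_parallel: bool) -> bool:
    # Need to repeat scales on every rank if act_ordering or
    # channelwise and RowParallelLinear
    is_channelwise = group_size == -1
    return act_order or (is_channelwise and is_row_parallel)


if __name__ == "__main__":
    # Illustrative (desc_act, group_size, is_row_parallel) configurations.
    cases = [
        (False, 128, False),  # grouped, column-parallel
        (False, 128, True),   # grouped, row-parallel
        (False, -1, True),    # channelwise, row-parallel
        (True, 128, True),    # act-order, row-parallel
    ]
    for act_order, group_size, is_row_parallel in cases:
        repeat = marlin_repeat_scales_on_all_ranks(act_order, group_size,
                                                   is_row_parallel)
        k_full = marlin_is_k_full(act_order, is_row_parallel)
        print(f"desc_act={act_order} group_size={group_size} "
              f"row_parallel={is_row_parallel} -> "
              f"repeat_scales={repeat} is_k_full={k_full}")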