[ Misc ] More Cleanup of Marlin (#6359)
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
parent
9da4aad44b
commit
babf52dade
@ -3,7 +3,7 @@
|
||||
# We use this for fp8, which HF does not support.
|
||||
#
|
||||
# Make sure you have lm-eval-harness installed:
|
||||
# pip install lm-eval==0.4.2
|
||||
# pip install lm-eval==0.4.3
|
||||
|
||||
usage() {
|
||||
echo``
|
||||
|
@ -10,8 +10,9 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
|
||||
marlin_permute_scales, marlin_sort_g_idx, replace_tensor,
|
||||
apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
|
||||
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
|
||||
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
|
||||
verify_marlin_supported, verify_marlin_supports_shape)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
@ -145,6 +146,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
) -> None:
|
||||
del output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
@ -158,32 +160,19 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
input_size=input_size,
|
||||
group_size=group_size)
|
||||
|
||||
# Detect sharding of scales/zp
|
||||
|
||||
# By default, no sharding over "input dim"
|
||||
scales_and_zp_size = input_size // group_size
|
||||
scales_and_zp_input_dim = None
|
||||
|
||||
if self.quant_config.desc_act:
|
||||
# Act-order case
|
||||
assert self.quant_config.group_size != -1
|
||||
|
||||
is_k_full = input_size_per_partition == input_size
|
||||
|
||||
# Determine sharding
|
||||
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
|
||||
self.quant_config.group_size,
|
||||
is_row_parallel):
|
||||
# By setting scale_dim == None, weight_loader will
|
||||
# repeat the scales on each GPU in TP>1 case.
|
||||
scales_and_zp_input_dim = None
|
||||
scales_and_zp_size = input_size // group_size
|
||||
else:
|
||||
# No act-order case
|
||||
|
||||
# K is always full due to full alignment with
|
||||
# group-size and shard of scales/zp
|
||||
is_k_full = True
|
||||
|
||||
# If this is a row-parallel case, then shard scales/zp
|
||||
if (input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1):
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
scales_and_zp_input_dim = 0
|
||||
|
||||
# Init buffers
|
||||
# By setting scale_dim == 0, weight_loader will
|
||||
# shard the scales in TP>1 case.
|
||||
scales_and_zp_input_dim = 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
# Quantized weights
|
||||
qweight = Parameter(
|
||||
@ -268,13 +257,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.input_size = input_size
|
||||
layer.is_k_full = is_k_full
|
||||
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
|
||||
is_row_parallel)
|
||||
|
||||
# Checkpoints are serialized in AutoGPTQ format, which is different from the
|
||||
# marlin format. This function is called after the weights are loaded.
|
||||
# Here, we handle the repacking, including the activation reordering case.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = layer.qweight.device
|
||||
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace(
|
||||
layer.output_size_per_partition, device)
|
||||
@ -312,22 +303,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
|
||||
|
||||
output = ops.gptq_marlin_gemm(reshaped_x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
g_idx=layer.g_idx,
|
||||
perm=layer.g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
is_k_full=layer.is_k_full)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
return apply_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.qweight,
|
||||
weight_scale=layer.scales,
|
||||
g_idx=layer.g_idx,
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
output_size_per_partition=layer.output_size_per_partition,
|
||||
input_size_per_partition=layer.input_size_per_partition,
|
||||
is_k_full=layer.is_k_full,
|
||||
bias=bias)
|
||||
|
@ -91,6 +91,18 @@ def marlin_make_workspace(output_size_per_partition: int,
|
||||
requires_grad=False)
|
||||
|
||||
|
||||
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
|
||||
return (not act_order) or (act_order and not is_row_parallel)
|
||||
|
||||
|
||||
def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int,
|
||||
is_row_parallel: bool) -> bool:
|
||||
# Need to repeat scales on every rank if act_ordering or
|
||||
# channelwise and RowParallelLinear
|
||||
is_channelwise = group_size == -1
|
||||
return act_order or (is_channelwise and is_row_parallel)
|
||||
|
||||
|
||||
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
|
||||
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
|
||||
requires_grad=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user