[Misc] Update marlin to use vLLMParameters (#7803)

This commit is contained in:
Dipika Sikka 2024-08-23 14:30:52 -04:00 committed by GitHub
parent 35ee2ad6b9
commit f1df5dbfd6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 34 deletions

View File

@ -15,4 +15,6 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
awq, casperhansen/mixtral-instruct-awq, main awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main

View File

@ -22,7 +22,8 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod" "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod"
] ]

View File

@ -9,7 +9,10 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
logger = init_logger(__name__) logger = init_logger(__name__)
@ -132,6 +135,7 @@ class MarlinLinearMethod(LinearMethodBase):
**extra_weight_attrs, **extra_weight_attrs,
): ):
del output_size # Unused. del output_size # Unused.
weight_loader = extra_weight_attrs["weight_loader"]
if params_dtype != torch.float16: if params_dtype != torch.float16:
raise ValueError( raise ValueError(
@ -170,64 +174,64 @@ class MarlinLinearMethod(LinearMethodBase):
"Each permutation group must reside on the same gpu") "Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32. # Quantized 4Bit weights packed into Int32.
qweight = Parameter( qweight = PackedvLLMParameter(
torch.empty( data=torch.empty(
input_size_per_partition // self.quant_config.tile_size, input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size // output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor, self.quant_config.pack_factor,
device="cuda", device="cuda",
dtype=torch.int32, dtype=torch.int32,
), ),
requires_grad=False, input_dim=0,
) output_dim=1,
set_weight_attrs( packed_dim=1,
qweight, packed_factor=self.quant_config.pack_factor,
{ marlin_tile_size=self.quant_config.tile_size,
"input_dim": 0, weight_loader=weight_loader)
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
# Determine if channelwise or not # Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition // input_size_per_partition //
self.quant_config.group_size) self.quant_config.group_size)
scales = Parameter( weight_scale_args = {
"data":
torch.empty( torch.empty(
input_groups, input_groups,
output_size_per_partition, output_size_per_partition,
device="cuda", device="cuda",
dtype=params_dtype, dtype=params_dtype,
), ),
requires_grad=False, "weight_loader":
) weight_loader
set_weight_attrs( }
scales, if input_groups == 1:
{ scales = ChannelQuantScaleParameter(output_dim=1,
"input_dim": None if input_groups == 1 else 0, **weight_scale_args)
"output_dim": 1, else:
}, scales = GroupQuantScaleParameter(output_dim=1,
) input_dim=0,
**weight_scale_args)
# Allocate workspace (Used for internal locking mechanism) # Allocate workspace (Used for internal locking mechanism)
max_workspace_size = ( max_workspace_size = (
output_size_per_partition // output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda", workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
dtype=torch.int), device="cuda",
requires_grad=False) dtype=torch.int),
weight_loader=weight_loader)
layer.register_parameter("B", qweight) layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s", scales) layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace) layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B = Parameter(layer.B.data, requires_grad=False)
layer.s = Parameter(layer.s.data, requires_grad=False)
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
def apply( def apply(
self, self,