[Misc] Update marlin
to use vLLMParameters (#7803)
This commit is contained in:
parent
35ee2ad6b9
commit
f1df5dbfd6
@ -15,4 +15,6 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
|
||||
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
|
||||
awq, casperhansen/mixtral-instruct-awq, main
|
||||
awq_marlin, casperhansen/mixtral-instruct-awq, main
|
||||
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
|
||||
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
|
||||
marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
|
||||
marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
|
@ -22,7 +22,8 @@ logger = init_logger(__name__)
|
||||
|
||||
WEIGHT_LOADER_V2_SUPPORTED = [
|
||||
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
|
||||
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
|
||||
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
|
||||
"MarlinLinearMethod"
|
||||
]
|
||||
|
||||
|
||||
|
@ -9,7 +9,10 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -132,6 +135,7 @@ class MarlinLinearMethod(LinearMethodBase):
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
@ -170,64 +174,64 @@ class MarlinLinearMethod(LinearMethodBase):
|
||||
"Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size,
|
||||
output_size_per_partition * self.quant_config.tile_size //
|
||||
self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
qweight,
|
||||
{
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
"packed_dim": 1,
|
||||
"pack_factor": self.quant_config.pack_factor,
|
||||
"marlin_tile_size": self.quant_config.tile_size,
|
||||
},
|
||||
)
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# Determine if channelwise or not
|
||||
input_groups = (1 if self.quant_config.group_size == -1 else
|
||||
input_size_per_partition //
|
||||
self.quant_config.group_size)
|
||||
|
||||
scales = Parameter(
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
scales,
|
||||
{
|
||||
"input_dim": None if input_groups == 1 else 0,
|
||||
"output_dim": 1,
|
||||
},
|
||||
)
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition //
|
||||
self.quant_config.min_n_threads) * self.quant_config.max_parallel
|
||||
workspace = Parameter(torch.zeros(max_workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.int),
|
||||
requires_grad=False)
|
||||
|
||||
workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
|
||||
device="cuda",
|
||||
dtype=torch.int),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("B", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("s", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
set_weight_attrs(workspace, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B = Parameter(layer.B.data, requires_grad=False)
|
||||
layer.s = Parameter(layer.s.data, requires_grad=False)
|
||||
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
|
Loading…
x
Reference in New Issue
Block a user