[Core] Set linear_weights
directly on the layer (#3977)
This commit is contained in:
parent
8afca50889
commit
a10d3056da
@ -2067,7 +2067,7 @@ void gptq_shuffle
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
|
||||
vllm::gptq::shuffle_exllama_weight(
|
||||
(uint32_t*) q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
|
||||
q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
|
||||
q_weight.size(0) * 32 / bit,
|
||||
q_weight.size(1),
|
||||
bit
|
||||
|
@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
|
||||
).cuda()
|
||||
|
||||
# Load the weights
|
||||
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
|
||||
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
|
||||
for i in range(config.num_local_experts):
|
||||
weights = (hf_moe.experts[i].w1.weight.data,
|
||||
hf_moe.experts[i].w3.weight.data)
|
||||
|
@ -368,7 +368,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
self.base_layer, x, bias)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -402,10 +402,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
if self.base_layer.skip_bias_add else None)
|
||||
return output, output_bias
|
||||
|
||||
@property
|
||||
def linear_weights(self):
|
||||
return self.base_layer.linear_weights
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
@ -505,7 +501,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -746,7 +742,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
@ -838,7 +834,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x)
|
||||
self.base_layer, x)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Create weights for a linear layer."""
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for a linear layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Apply the weights to the input tensor."""
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
def __init__(self, separate_bias_add: bool = False):
|
||||
self.separate_bias_add = separate_bias_add
|
||||
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
weight = Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
return {"weight": weight}
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = weights["weight"]
|
||||
weight = layer.weight
|
||||
if self.separate_bias_add:
|
||||
if bias is not None:
|
||||
return F.linear(x, weight) + bias
|
||||
@ -111,12 +118,9 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
self.linear_method.create_weights(self, self.input_size,
|
||||
self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=self.params_dtype))
|
||||
@ -126,7 +130,7 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
output = self.linear_method.apply_weights(self.linear_weights, x, bias)
|
||||
output = self.linear_method.apply_weights(self, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
@ -177,13 +181,13 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size_per_partition, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
self.linear_method.create_weights(self,
|
||||
self.input_size,
|
||||
self.output_size_per_partition,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
@ -211,8 +215,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_, bias)
|
||||
output_parallel = self.linear_method.apply_weights(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
@ -523,13 +526,13 @@ class RowParallelLinear(torch.nn.Module):
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size_per_partition, self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
self.linear_method.create_weights(self,
|
||||
self.input_size_per_partition,
|
||||
self.output_size,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
@ -569,7 +572,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_parallel)
|
||||
self, input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
|
@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: AWQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
@ -136,19 +137,21 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
"input_dim": 0,
|
||||
"output_dim": 1,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"qzeros": qzeros,
|
||||
"scales": scales,
|
||||
}
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
set_weight_attrs(qzeros, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, Any],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
scales = weights["scales"]
|
||||
qzeros = weights["qzeros"]
|
||||
qweight = layer.qweight
|
||||
scales = layer.scales
|
||||
qzeros = layer.qzeros
|
||||
pack_factor = self.quant_config.pack_factor
|
||||
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
@ -163,5 +166,5 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
|
||||
pack_factor)
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
|
@ -89,12 +89,14 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
@ -179,37 +181,40 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
"input_dim": scale_and_zero_input_dim,
|
||||
"output_dim": 1,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"g_idx": g_idx,
|
||||
"qzeros": qzeros,
|
||||
"scales": scales,
|
||||
"exllama_state": exllama_state,
|
||||
}
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
set_weight_attrs(g_idx, extra_weight_attrs)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
set_weight_attrs(qzeros, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
layer.exllama_state = exllama_state
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, Any],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
qweight = layer.qweight
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
|
||||
torch.int)
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
weights["g_idx"] = torch.empty((1, 1), device="meta")
|
||||
weights["exllama_state"] = ExllamaState.READY
|
||||
ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
device=layer.g_idx.device)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
self.quant_config.weight_bits)
|
||||
output = ops.gptq_gemm(reshaped_x, weights["qweight"],
|
||||
weights["qzeros"], weights["scales"],
|
||||
weights["g_idx"],
|
||||
weights["exllama_state"] == ExllamaState.READY,
|
||||
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
|
@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase):
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
) -> Dict[str, Any]:
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
|
||||
if params_dtype != torch.float16:
|
||||
@ -187,21 +189,22 @@ class MarlinLinearMethod(LinearMethodBase):
|
||||
dtype=torch.int),
|
||||
requires_grad=False)
|
||||
|
||||
return {
|
||||
"B": qweight,
|
||||
"s": scales,
|
||||
"workspace": workspace,
|
||||
}
|
||||
layer.register_parameter("B", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("s", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
set_weight_attrs(workspace, extra_weight_attrs)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
weights: Dict[str, Any],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = weights["B"]
|
||||
scales = weights["s"]
|
||||
workspace = weights["workspace"]
|
||||
qweight = layer.B
|
||||
scales = layer.s
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
|
@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: SqueezeLLMConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
if input_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
@ -103,17 +104,18 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
set_weight_attrs(lookup_table, {
|
||||
"output_dim": 0,
|
||||
})
|
||||
return {
|
||||
"qweight": qweight,
|
||||
"lookup_table": lookup_table,
|
||||
}
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("lookup_table", lookup_table)
|
||||
set_weight_attrs(lookup_table, extra_weight_attrs)
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, Any],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
qweight = weights["qweight"]
|
||||
lookup_table = weights["lookup_table"]
|
||||
qweight = layer.qweight
|
||||
lookup_table = layer.lookup_table
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
if is_hip():
|
||||
@ -126,5 +128,5 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
out.add_(bias)
|
||||
return out.reshape(out_shape)
|
||||
|
Loading…
x
Reference in New Issue
Block a user