[Core] Set linear_weights directly on the layer (#3977)

Antoni Baum 2024-04-11 13:35:51 -07:00 committed by GitHub
parent 8afca50889
commit a10d3056da
8 changed files with 114 additions and 102 deletions
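In short: create_weights() no longer returns a Dict[str, Any] of weights that each layer stashes as self.linear_weights; it now receives the layer module and registers the parameters on it directly, and apply_weights() takes the layer in place of the dict. A minimal sketch of the new shape of the API (simplified signatures, not vLLM's actual code):

# Minimal sketch, assuming simplified signatures; the real classes live in
# vllm/model_executor/layers/linear.py (see the diff below).
from typing import Optional

import torch
import torch.nn.functional as F
from torch.nn import Parameter


class UnquantizedLinearMethodSketch:
    """After this commit: weights are registered on the layer itself."""

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        weight = Parameter(torch.empty(output_size, input_size,
                                       dtype=params_dtype),
                           requires_grad=False)
        # The layer, not a returned dict, now owns the parameter.
        layer.register_parameter("weight", weight)
        for key, value in extra_weight_attrs.items():
            setattr(weight, key, value)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)


layer = torch.nn.Module()
method = UnquantizedLinearMethodSketch()
method.create_weights(layer, 16, 32, torch.float32)
torch.nn.init.normal_(layer.weight)
out = method.apply_weights(layer, torch.randn(4, 16))  # shape (4, 32)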

csrc/quantization/gptq/q_gemm.cu

@@ -2067,7 +2067,7 @@ void gptq_shuffle
   const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
   vllm::gptq::shuffle_exllama_weight(
     (uint32_t*) q_weight.data_ptr(),
-    q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
+    q_perm.device().is_meta() || q_perm.numel() == 0 ? NULL : (int*) q_perm.data_ptr(),
     q_weight.size(0) * 32 / bit,
     q_weight.size(1),
     bit
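(This kernel-side guard pairs with the gptq.py change further down: when desc_act is off, the placeholder g_idx becomes an empty tensor on the weight's real device instead of a (1, 1) meta tensor, so the shuffle entry point must treat an empty permutation the same as a meta one and pass NULL.)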

tests/kernels/test_moe.py

@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
     ).cuda()

     # Load the weights
-    vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data
+    vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)

vllm/lora/layers.py

@@ -368,7 +368,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora(
             x,
             self.lora_a_stacked,
@@ -402,10 +402,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias

-    @property
-    def linear_weights(self):
-        return self.base_layer.linear_weights
-
     @classmethod
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
@@ -505,7 +501,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -746,7 +742,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
     def apply_weights(self, x: torch.Tensor,
                       bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x, bias)
+            self.base_layer, x, bias)
         _apply_lora_packed_nslice(
             x,
             self.lora_a_stacked,
@@ -838,7 +834,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.linear_method.apply_weights(
-            self.base_layer.linear_weights, x)
+            self.base_layer, x)
         _apply_lora(
             x,
             self.lora_a_stacked,
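The pattern in all four wrappers is the same: delegate the dense GEMM to the base layer's linear method, handing over the base layer itself, then add the low-rank update on top; the removed linear_weights property is no longer needed. A hedged sketch of that delegation (LoRALinearSketch and its shapes are illustrative, not vLLM's helpers; it assumes the base layer carries .weight and .linear_method):

import torch


class LoRALinearSketch(torch.nn.Module):
    def __init__(self, base_layer: torch.nn.Module, rank: int = 8):
        super().__init__()
        self.base_layer = base_layer
        out_features, in_features = base_layer.weight.shape
        self.lora_a = torch.nn.Parameter(torch.zeros(rank, in_features))
        self.lora_b = torch.nn.Parameter(torch.zeros(out_features, rank))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # After this commit the wrapper passes the base layer as-is,
        # instead of forwarding a linear_weights dict.
        out = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, None)
        return out + x @ self.lora_a.t() @ self.lora_b.t()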

vllm/model_executor/layers/linear.py

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import torch
 import torch.nn.functional as F
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""

     @abstractmethod
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Create weights for a linear layer."""
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer.
+           The weights will be set as attributes of the layer."""
         raise NotImplementedError

     @abstractmethod
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
+        """Apply the weights in layer to the input tensor.
+        Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add

-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        return {"weight": weight}
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = weights["weight"]
+        weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias
@@ -111,12 +118,9 @@ class ReplicatedLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
+        self.linear_method.create_weights(self, self.input_size,
+                                          self.output_size, self.input_size,
+                                          self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -126,7 +130,7 @@ class ReplicatedLinear(torch.nn.Module):

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+        output = self.linear_method.apply_weights(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
@@ -177,13 +181,13 @@ class ColumnParallelLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size,
+                                          self.output_size_per_partition,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -211,8 +215,7 @@ class ColumnParallelLinear(torch.nn.Module):
         bias = self.bias if not self.skip_bias_add else None

         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_, bias)
+        output_parallel = self.linear_method.apply_weights(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -523,13 +526,13 @@ class RowParallelLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size_per_partition,
+                                          self.output_size,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)

         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -569,7 +572,7 @@ class RowParallelLinear(torch.nn.Module):

         # Matrix multiply.
         output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_parallel)
+            self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else:
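The payoff of registering weights on the module is that quantized and unquantized parameters alike become visible through the standard nn.Module machinery, which is why ColumnParallelLinear and RowParallelLinear can now pass weight_loader=self.weight_loader through **extra_weight_attrs instead of looping over a returned dict. A small standalone illustration (simplified; the lambda loader is a stand-in):

import torch

layer = torch.nn.Module()
qweight = torch.nn.Parameter(torch.empty(128, 16, dtype=torch.int32),
                             requires_grad=False)
layer.register_parameter("qweight", qweight)
# Equivalent of set_weight_attrs(qweight, {"weight_loader": ...}):
qweight.weight_loader = lambda param, loaded: param.data.copy_(loaded)

print(list(layer.named_parameters()))   # [('qweight', Parameter ...)]
print("qweight" in layer.state_dict())  # True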

vllm/model_executor/layers/quantization/awq.py

@@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -136,19 +137,21 @@ class AWQLinearMethod(LinearMethodBase):
             "input_dim": 0,
             "output_dim": 1,
         })
-        return {
-            "qweight": qweight,
-            "qzeros": qzeros,
-            "scales": scales,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)

     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
-        scales = weights["scales"]
-        qzeros = weights["qzeros"]
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
         pack_factor = self.quant_config.pack_factor
         out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
         reshaped_x = x.reshape(-1, x.shape[-1])
@@ -163,5 +166,5 @@ class AWQLinearMethod(LinearMethodBase):
             out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
                                pack_factor)
         if bias is not None:
-            out = out + bias
+            out.add_(bias)
         return out.reshape(out_shape)
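A side change in the quantized methods: the bias is now added in place (out.add_(bias) instead of out = out + bias). This is safe because out is a fresh tensor produced by the quantized GEMM, and it avoids allocating a second output-sized buffer:

import torch

out = torch.zeros(4, 8)  # stands in for the fresh awq_gemm output
bias = torch.ones(8)
out.add_(bias)           # broadcasts over rows; no new allocation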

vllm/model_executor/layers/quantization/gptq.py

@@ -89,12 +89,14 @@ class GPTQLinearMethod(LinearMethodBase):

     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.
         if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
@@ -179,37 +181,40 @@ class GPTQLinearMethod(LinearMethodBase):
                 "input_dim": scale_and_zero_input_dim,
                 "output_dim": 1,
             })
-        return {
-            "qweight": qweight,
-            "g_idx": g_idx,
-            "qzeros": qzeros,
-            "scales": scales,
-            "exllama_state": exllama_state,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("g_idx", g_idx)
+        set_weight_attrs(g_idx, extra_weight_attrs)
+        layer.register_parameter("qzeros", qzeros)
+        set_weight_attrs(qzeros, extra_weight_attrs)
+        layer.register_parameter("scales", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+
+        layer.exllama_state = exllama_state

     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
+        qweight = layer.qweight
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         # exllama needs to shuffle the weight after the weight is loaded
         # here we do the shuffle on first forward pass
-        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
+        if layer.exllama_state == ExllamaState.UNINITIALIZED:
             if self.quant_config.desc_act:
-                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
-                    torch.int)
+                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
             else:
-                weights["g_idx"] = torch.empty((1, 1), device="meta")
-            weights["exllama_state"] = ExllamaState.READY
-            ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
+                layer.g_idx.data = torch.empty((0, ),
+                                               device=layer.g_idx.device)
+            layer.exllama_state = ExllamaState.READY
+            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                              self.quant_config.weight_bits)
-        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
-                               weights["qzeros"], weights["scales"],
-                               weights["g_idx"],
-                               weights["exllama_state"] == ExllamaState.READY,
+        output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
+                               layer.scales, layer.g_idx,
+                               layer.exllama_state == ExllamaState.READY,
                                self.quant_config.weight_bits)
         if bias is not None:
-            output = output + bias
+            output.add_(bias)
         return output.reshape(out_shape)
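Mutable per-layer state (exllama_state) also moves onto the module, preserving the deferred shuffle-on-first-forward behavior after the real weights have been loaded. A condensed sketch of that pattern (standalone; the actual ops.gptq_shuffle call is replaced by a comment):

import enum

import torch


class ExllamaState(enum.Enum):
    UNINITIALIZED = 0
    READY = 1


def maybe_shuffle(layer: torch.nn.Module, desc_act: bool) -> None:
    if layer.exllama_state != ExllamaState.UNINITIALIZED:
        return
    if desc_act:
        # Rewriting .data keeps the registered Parameter object alive.
        layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
    else:
        layer.g_idx.data = torch.empty((0, ), device=layer.g_idx.device)
    layer.exllama_state = ExllamaState.READY
    # ops.gptq_shuffle(layer.qweight, layer.g_idx, weight_bits) runs here.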

vllm/model_executor/layers/quantization/marlin.py

@@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase):

     def create_weights(
         self,
+        layer: torch.nn.Module,
         input_size_per_partition: int,
         output_size_per_partition: int,
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
+        **extra_weight_attrs,
+    ):
         del output_size  # Unused.

         if params_dtype != torch.float16:
@@ -187,21 +189,22 @@ class MarlinLinearMethod(LinearMethodBase):
                               dtype=torch.int),
                               requires_grad=False)

-        return {
-            "B": qweight,
-            "s": scales,
-            "workspace": workspace,
-        }
+        layer.register_parameter("B", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("s", scales)
+        set_weight_attrs(scales, extra_weight_attrs)
+        layer.register_parameter("workspace", workspace)
+        set_weight_attrs(workspace, extra_weight_attrs)

     def apply_weights(
         self,
-        weights: Dict[str, Any],
+        layer: torch.nn.Module,
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        qweight = weights["B"]
-        scales = weights["s"]
-        workspace = weights["workspace"]
+        qweight = layer.B
+        scales = layer.s
+        workspace = layer.workspace

         x_2d = x.view(-1, x.shape[-1])

vllm/model_executor/layers/quantization/squeezellm.py

@@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: SqueezeLLMConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
@@ -103,17 +104,18 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         set_weight_attrs(lookup_table, {
             "output_dim": 0,
         })
-        return {
-            "qweight": qweight,
-            "lookup_table": lookup_table,
-        }
+
+        layer.register_parameter("qweight", qweight)
+        set_weight_attrs(qweight, extra_weight_attrs)
+        layer.register_parameter("lookup_table", lookup_table)
+        set_weight_attrs(lookup_table, extra_weight_attrs)

     def apply_weights(self,
-                      weights: Dict[str, Any],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        qweight = weights["qweight"]
-        lookup_table = weights["lookup_table"]
+        qweight = layer.qweight
+        lookup_table = layer.lookup_table
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         if is_hip():
@@ -126,5 +128,5 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
         if bias is not None:
-            out = out + bias
+            out.add_(bias)
         return out.reshape(out_shape)