2024-04-26 13:41:14 -07:00
|
|
|
from abc import abstractmethod
|
2024-06-01 13:51:10 -07:00
|
|
|
from typing import Dict, List, Optional, Tuple
|
2023-11-15 22:50:41 -08:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
2024-08-06 07:54:23 +08:00
|
|
|
from torch.nn.parameter import Parameter, UninitializedParameter
|
2023-11-15 22:50:41 -08:00
|
|
|
|
2024-04-10 15:33:30 -07:00
|
|
|
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
|
|
|
get_tensor_model_parallel_world_size,
|
|
|
|
split_tensor_along_last_dim,
|
|
|
|
tensor_model_parallel_all_gather,
|
|
|
|
tensor_model_parallel_all_reduce)
|
2024-03-25 23:59:47 +09:00
|
|
|
from vllm.logger import init_logger
|
2024-04-26 13:41:14 -07:00
|
|
|
from vllm.model_executor.layers.quantization.base_config import (
|
|
|
|
QuantizationConfig, QuantizeMethodBase)
|
2023-11-15 22:50:41 -08:00
|
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
|
|
|
|
|
|
# Module-level logger; used below for warnings emitted while loading weights.
logger = init_logger(__name__)
|
|
|
|
|
|
|
|
|
2024-03-01 14:47:51 -06:00
|
|
|
def adjust_marlin_shard(param, shard_size, shard_offset):
    """Rescale a shard's size and offset by the param's Marlin tile size.

    If ``param`` defines a non-None ``marlin_tile_size`` attribute, both
    ``shard_size`` and ``shard_offset`` are multiplied by it; otherwise they
    are returned untouched.

    Returns:
        A ``(shard_size, shard_offset)`` tuple.
    """
    tile = getattr(param, "marlin_tile_size", None)
    if tile is not None:
        return shard_size * tile, shard_offset * tile
    return shard_size, shard_offset
|
|
|
|
|
|
|
|
|
2024-06-01 13:51:10 -07:00
|
|
|
def adjust_bitsandbytes_shard(param: Parameter,
                              qkv_offsets: Dict[str, Tuple[int, int]],
                              loaded_shard_id: str) -> Tuple[int, int]:
    """Adjust the quantization offsets and sizes for BitsAndBytes sharding.

    ``qkv_offsets`` maps a shard id to an ``(offset, size)`` pair expressed
    in unquantized units; its ``"total"`` entry is a 2-tuple whose first
    element is the unquantized total. The requested shard's offset and size
    are rescaled (with floor division) to the quantized extent
    ``param.data.shape[0]``.

    Returns:
        ``(quantized_size, quantized_offset)`` -- size first, then offset.
    """
    unquantized_total, _ = qkv_offsets["total"]
    offset, size = qkv_offsets[loaded_shard_id]
    packed_rows = param.data.shape[0]

    def _rescale(value: int) -> int:
        # Proportional mapping from unquantized to quantized units.
        return value * packed_rows // unquantized_total

    return _rescale(size), _rescale(offset)
|
|
|
|
|
|
|
|
|
2024-06-30 19:06:27 -04:00
|
|
|
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select one slot of a fused per-shard scale array for loading.

    For fused modules (QKV and MLP) we have an array of length N that holds
    1 scale for each "logical" matrix, so the param is an array of length N.
    The loaded_weight corresponds to one of the shards on disk. Here, we
    slice the param based on the shard_id for loading.

    Args:
        param: The fused array (indexed along dim 0).
        loaded_weight: The scale read from disk; either 0-d or of shape
            ``(1, ...)``.
        shard_id: ``"q"``/``"k"``/``"v"`` (mapped to 0/1/2) or an int index.

    Returns:
        ``(param[shard_id], loaded_weight)`` where ``loaded_weight`` has had
        its leading length-1 dim removed if it had one.

    Raises:
        ValueError: If ``shard_id`` is an unknown string or neither a str
            nor an int.
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown string ids the same way as unknown id types instead
        # of leaking a bare KeyError from the dict lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight
|
|
|
|
|
|
|
|
|
2024-04-26 13:41:14 -07:00
|
|
|
class LinearMethodBase(QuantizeMethodBase):
    """Base class for different (maybe quantized) linear methods.

    Subclasses implement ``create_weights`` to register parameters on a
    layer and ``apply`` to run the layer's computation with them.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs
    ):
        """Create weights for a linear layer.

        The weights will be set as attributes of the layer.

        Args:
            layer: The layer that is using the LinearMethodBase factory.
            input_size_per_partition: Size of the weight input dim on rank X.
            output_partition_sizes: Sizes of the output dim of each logical
                weight on rank X. E.g., output_partition_sizes for QKVLinear
                is a list containing the widths of Wq, Wk, Wv on rank X.
            input_size: Size of the input dim of the weight across all ranks.
            output_size: Size of the output dim of the weight across all
                ranks.
            params_dtype: Datatype of the parameters.
            **extra_weight_attrs: Extra attributes for the created
                parameters (callers in this file pass ``weight_loader`` and
                ``prefix``).
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Apply the weights in layer to the input tensor.

        Expects create_weights to have been called before on the layer.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class UnquantizedLinearMethod(LinearMethodBase):
    """Linear method without quantization: a plain dense weight."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register an uninitialized dense ``weight`` on ``layer``.

        The tensor has shape ``(sum(output_partition_sizes),
        input_size_per_partition)`` and is tagged with ``output_dim=0`` and
        ``input_dim=1``. ``input_size`` and ``output_size`` are not used by
        this method. ``extra_weight_attrs`` are attached to the parameter
        after it has been registered.
        """
        rows = sum(output_partition_sizes)
        storage = torch.empty(rows,
                              input_size_per_partition,
                              dtype=params_dtype)
        weight = Parameter(storage, requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``F.linear(x, layer.weight, bias)``."""
        dense_weight = layer.weight
        return F.linear(x, dense_weight, bias)
|
2023-11-15 22:50:41 -08:00
|
|
|
|
|
|
|
|
2024-04-26 13:41:14 -07:00
|
|
|
class LinearBase(torch.nn.Module):
    """Base linear layer.

    Stores the common configuration and selects the quantization method;
    subclasses create the parameters and implement ``forward``.

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters. When None, the current
            ``torch.get_default_dtype()`` is used.
        quant_config: Quantization configure. When None, the layer uses
            ``UnquantizedLinearMethod``.
        prefix: The name of the layer in the state dict, including all
            parents (e.g. model.layers.0.qkv_proj). Forwarded to
            ``quant_config.get_quant_method``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: Optional[torch.dtype] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.skip_bias_add = skip_bias_add
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype
        # Declared Optional; subclasses assert it is not None before use.
        if quant_config is None:
            self.quant_method: Optional[
                QuantizeMethodBase] = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self,
                                                              prefix=prefix)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class ReplicatedLinear(LinearBase):
    """Replicated linear layer.

    Every rank holds the full, unsharded weight (and bias).

    Args:
        input_size: input dimension of the linear layer.
        output_size: output dimension of the linear layer.
        bias: If true, add bias.
        skip_bias_add: If true, skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size,
                         output_size,
                         skip_bias_add,
                         params_dtype,
                         quant_config,
                         prefix=prefix)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
        self.quant_method.create_weights(self,
                                         self.input_size, [self.output_size],
                                         self.input_size,
                                         self.output_size,
                                         self.params_dtype,
                                         weight_loader=self.weight_loader,
                                         prefix=prefix)

        if bias:
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy ``loaded_weight`` into ``param`` without any sharding.

        A 0-d ``loaded_weight`` is reshaped to ``(1,)`` first; afterwards the
        sizes must match exactly (asserted).
        """
        # If the weight on disk does not have a shape, give it one
        # (such scales for AutoFp8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(
            self,
            x: torch.Tensor) -> Tuple[torch.Tensor, Optional[Parameter]]:
        """Run the layer.

        Returns:
            ``(output, output_bias)``. The bias is added inside the quant
            method unless ``skip_bias_add`` is set, in which case it is
            returned as ``output_bias`` instead (otherwise that is None).
        """
        bias = self.bias if not self.skip_bias_add else None
        assert self.quant_method is not None
        output = self.quant_method.apply(self, x, bias)
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        return s
|
|
|
|
|
2023-11-15 22:50:41 -08:00
|
|
|
|
2024-04-26 13:41:14 -07:00
|
|
|
class ColumnParallelLinear(LinearBase):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        output_sizes: Not used by this class; accepted only for backward
                      compatibility. Packed subclasses (QKV / merged column
                      layers) assign ``self.output_sizes`` before calling this
                      constructor, and the per-shard partition sizes are
                      derived from that attribute.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 output_sizes: Optional[List[int]] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.gather_output = gather_output

        # Divide the weight matrix along the last dimension.
        tp_size = get_tensor_model_parallel_world_size()
        assert self.quant_method is not None
        self.output_size_per_partition = divide(self.output_size, tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(size, tp_size) for size in self.output_sizes
            ]

        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size,
            output_partition_sizes=self.output_partition_sizes,
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader,
            prefix=prefix)
        if bias:
            # Use the dtype resolved by LinearBase (torch default when
            # params_dtype is None) so the bias always matches the weights.
            self.bias = Parameter(
                torch.empty(self.output_size_per_partition,
                            dtype=self.params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Load this rank's column slice of ``loaded_weight`` into ``param``.

        If ``param`` has an ``output_dim`` attribute, ``loaded_weight`` is
        narrowed along that dim to the block starting at
        ``tp_rank * param.shape[output_dim]``. A 0-d tensor (e.g. an AutoFP8
        scale) is reshaped to ``(1,)``. Shapes must then match (asserted).
        GGUF params record their weight type and are materialized lazily.
        """
        tp_rank = get_tensor_model_parallel_rank()
        output_dim = getattr(param, "output_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if output_dim is not None:
            shard_size = param_data.shape[output_dim]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_) -> Tuple[torch.Tensor, Optional[Parameter]]:
        """Run the layer.

        Returns:
            ``(output, output_bias)``. ``output`` is all-gathered across
            ranks when ``gather_output`` is set. ``output_bias`` is the bias
            when ``skip_bias_add`` is set, otherwise None.
        """
        bias = self.bias if not self.skip_bias_add else None

        # Matrix multiply.
        assert self.quant_method is not None
        output_parallel = self.quant_method.apply(self, input_, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias

    def extra_repr(self) -> str:
        s = f"in_features={self.input_size}"
        s += f", output_features={self.output_size_per_partition}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={get_tensor_model_parallel_world_size()}"
        s += f", gather_output={self.gather_output}"
        return s
|
|
|
|
|
2023-11-15 22:50:41 -08:00
|
|
|
|
|
|
|
class MergedColumnParallelLinear(ColumnParallelLinear):
    """Packed linear layers with column parallelism.

    Similar to ColumnParallelLinear, but the weight matrix is concatenated
    along the output dimension. When the weight matrix is loaded, the
    different partitions are sharded separately.

    Args:
        input_size: input dimension of the linear layer.
        output_sizes: list of output dimensions of the linear layer.
                      Each entry must be divisible by the tensor-parallel
                      world size (asserted).
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make the output
                       available to all GPUs, otherwise, every GPU will have
                       its own output.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other element-wise operations. We
                       skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        quant_config: Quantization configure.
        prefix: The name of the layer in the state dict, including all parents
                (e.g. model.layers.0.qkv_proj)
    """

    def __init__(self,
                 input_size: int,
                 output_sizes: List[int],
                 bias: bool = True,
                 gather_output: bool = False,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        # Must be set before super().__init__(): ColumnParallelLinear checks
        # hasattr(self, "output_sizes") to derive per-shard partition sizes.
        self.output_sizes = output_sizes
        tp_size = get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size=input_size,
                         output_size=sum(output_sizes),
                         bias=bias,
                         gather_output=gather_output,
                         skip_bias_add=skip_bias_add,
                         params_dtype=params_dtype,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self,
                      param: Parameter,
                      loaded_weight: torch.Tensor,
                      loaded_shard_id: Optional[int] = None):
        """Load a checkpoint tensor, or one logical shard of it, into param.

        Args:
            param: Destination parameter. Its attributes (``output_dim``,
                ``packed_dim``/``pack_factor``, ``is_metadata``,
                ``needs_scalar_to_array``, ``use_bitsandbytes``, the GGUF
                flags, ``ignore_warning``) select the loading path below.
            loaded_weight: Tensor read from the checkpoint.
            loaded_shard_id: Index into ``self.output_sizes`` naming which
                sub-matrix ``loaded_weight`` holds, or None when the
                checkpoint already stores the fused tensor; in that case the
                tensor is split and this method recurses once per shard.
        """

        # Special case for GGUF
        # initialize GGUF param after we know the quantize type
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            # Record this shard's scalar quant type at slot loaded_shard_id.
            # NOTE(review): indexing assumes loaded_shard_id is not None here.
            param.data[loaded_shard_id].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
            return

        if is_gguf_weight and isinstance(param, UninitializedParameter):
            from gguf.constants import GGML_QUANT_SIZES

            # Allocate the lazily-created GGUF weight: its second dim is the
            # widest row (ori_shape[1] // block_size * type_size) over the
            # shard quant types already recorded on self.qweight_type.
            # NOTE(review): assumes the weight-type shards were loaded first.
            ori_shape = param.tensor_shape
            weight_types = self.qweight_type.shard_weight_type.values()
            row_size = []
            for weight_type in weight_types:
                block_size, type_size = GGML_QUANT_SIZES[weight_type]
                row_size.append(ori_shape[1] // block_size * type_size)
            q_shape = (ori_shape[0], max(row_size))
            param.materialize(q_shape, dtype=loaded_weight.dtype)

        param_data = param.data
        output_dim = getattr(param, "output_dim", None)
        # Special case for AQLM codebooks.
        is_metadata = getattr(param, "is_metadata", False)
        # Special case for per-tensor scale to load scalar into fused array.
        needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

        if loaded_shard_id is None:
            # Loaded weight is already fused on disk (qkv/mlp).
            if output_dim is None:
                # No sharding dim: copy as a whole (scales go to slot 0).
                if needs_scalar_to_array:
                    param_data, loaded_weight = adjust_scalar_to_fused_array(
                        param_data, loaded_weight, 0)

                assert param_data.shape == loaded_weight.shape
                param_data.copy_(loaded_weight)
                return
            # Split the fused tensor using the full (unpartitioned)
            # output_sizes, then recurse with an explicit shard index.
            current_shard_offset = 0
            shard_offsets: List[Tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                shard_offsets.append((i, current_shard_offset, output_size))
                current_shard_offset += output_size
            packed_dim = getattr(param, "packed_dim", None)
            for shard_id, shard_offset, shard_size in shard_offsets:
                # Special case for Quantization.
                # If quantized, we need to adjust the offset and size to account
                # for the packing.
                if packed_dim == output_dim:
                    shard_size = shard_size // param.pack_factor
                    shard_offset = shard_offset // param.pack_factor
                    # Special case for Marlin.
                    shard_size, shard_offset = adjust_marlin_shard(
                        param, shard_size, shard_offset)

                loaded_weight_shard = loaded_weight.narrow(
                    output_dim, shard_offset, shard_size)
                self.weight_loader(param, loaded_weight_shard, shard_id)
            return

        assert loaded_shard_id < len(self.output_sizes)
        tp_rank = get_tensor_model_parallel_rank()
        tp_size = get_tensor_model_parallel_world_size()
        if output_dim is not None:
            # Location of this shard inside the per-rank param, in per-rank
            # (tp-partitioned) units.
            shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
            shard_size = self.output_sizes[loaded_shard_id] // tp_size
            # Special case for quantization.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            packed_dim = getattr(param, "packed_dim", None)
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor
                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            # bitsandbytes: size/offset come from the checkpoint tensor's own
            # extent along output_dim.
            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
            if use_bitsandbytes:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id

            # GGUF: same sizing rule, and record which shards were loaded
            # and their shapes on the param.
            if is_gguf_weight:
                shard_size = loaded_weight.shape[output_dim]
                shard_offset = loaded_weight.shape[output_dim] * \
                    loaded_shard_id
                param.shard_id.append(loaded_shard_id)
                param.shard_size[loaded_shard_id] = loaded_weight.shape

            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
            # Take this rank's slice of the checkpoint shard.
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
        # Special case for AQLM codebooks.
        elif is_metadata:
            # metadata indicates fixed size concatenated along dim 0
            shard_size = loaded_weight.shape[0]
            shard_offset = loaded_shard_id * shard_size
            param_data = param_data.narrow(0, shard_offset, shard_size)

        # Special case for per-tensor scales in fused case.
        elif needs_scalar_to_array:
            param_data, loaded_weight = adjust_scalar_to_fused_array(
                param_data, loaded_weight, loaded_shard_id)

        else:
            # No sharding information: the whole param is overwritten.
            ignore_warning = getattr(param, "ignore_warning", False)
            if not ignore_warning:
                logger.warning(
                    "Loading a weight without `output_dim` attribute in "
                    "MergedColumnParallelLinear, assume the weight is "
                    "the same for all partitions.")

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
|
|
|
|
|
|
|
|
|
|
|
|
class QKVParallelLinear(ColumnParallelLinear):
|
|
|
|
"""Linear layers for the attention's QKV transformation.
|
|
|
|
|
|
|
|
Linear layers for the linear transformation of the query, key, and value
|
|
|
|
vectors in the attention layer. The weight matrix is concatenated along
|
|
|
|
the output dimension. The layer is parallelized along the head dimension.
|
|
|
|
When the number of key/value heads is smaller than the number of query
|
|
|
|
heads (e.g., multi-query/grouped-query attention), the key/value head may
|
|
|
|
be replicated while the query heads are partitioned.
|
|
|
|
|
|
|
|
Args:
|
|
|
|
hidden_size: input hidden state size of the transformer.
|
|
|
|
head_size: size of each attention head.
|
|
|
|
total_num_heads: total number of attention query heads.
|
|
|
|
total_num_kv_heads: total number of attention key/value heads. If
|
|
|
|
None, assume total_num_kv_heads = total_num_heads.
|
|
|
|
bias: If true, add bias.
|
|
|
|
skip_bias_add: This was added to enable performance optimizations where
|
|
|
|
bias can be fused with other element-wise operations. we
|
|
|
|
skip adding bias but instead return it.
|
|
|
|
params_dtype: Data type for the parameters.
|
2024-04-26 13:41:14 -07:00
|
|
|
quant_config: Quantization configure.
|
2024-07-18 22:39:18 -04:00
|
|
|
prefix: The name of the layer in the state dict, including all parents
|
|
|
|
(e.g. model.layers.0.qkv_proj)
|
2023-11-15 22:50:41 -08:00
|
|
|
"""
|
|
|
|
|
2024-05-23 17:29:18 -04:00
|
|
|
def __init__(self,
|
|
|
|
hidden_size: int,
|
|
|
|
head_size: int,
|
|
|
|
total_num_heads: int,
|
|
|
|
total_num_kv_heads: Optional[int] = None,
|
|
|
|
bias: bool = True,
|
|
|
|
skip_bias_add: bool = False,
|
|
|
|
params_dtype: Optional[torch.dtype] = None,
|
2024-07-18 22:39:18 -04:00
|
|
|
quant_config: Optional[QuantizationConfig] = None,
|
2024-07-20 12:36:57 -04:00
|
|
|
prefix: str = ""):
|
2023-11-15 22:50:41 -08:00
|
|
|
self.hidden_size = hidden_size
|
|
|
|
self.head_size = head_size
|
|
|
|
self.total_num_heads = total_num_heads
|
|
|
|
if total_num_kv_heads is None:
|
|
|
|
total_num_kv_heads = total_num_heads
|
|
|
|
self.total_num_kv_heads = total_num_kv_heads
|
|
|
|
# Divide the weight matrix along the last dimension.
|
|
|
|
tp_size = get_tensor_model_parallel_world_size()
|
|
|
|
self.num_heads = divide(self.total_num_heads, tp_size)
|
|
|
|
if tp_size >= self.total_num_kv_heads:
|
|
|
|
self.num_kv_heads = 1
|
|
|
|
self.num_kv_head_replicas = divide(tp_size,
|
|
|
|
self.total_num_kv_heads)
|
|
|
|
else:
|
|
|
|
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
|
|
|
self.num_kv_head_replicas = 1
|
|
|
|
input_size = self.hidden_size
|
|
|
|
output_size = (self.num_heads +
|
|
|
|
2 * self.num_kv_heads) * tp_size * self.head_size
|
2024-05-23 17:29:18 -04:00
|
|
|
self.output_sizes = [
|
|
|
|
self.num_heads * self.head_size * tp_size, # q_proj
|
|
|
|
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
|
|
|
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
2024-04-23 13:59:33 -04:00
|
|
|
]
|
|
|
|
|
2024-05-23 17:29:18 -04:00
|
|
|
super().__init__(input_size=input_size,
|
|
|
|
output_size=output_size,
|
|
|
|
bias=bias,
|
|
|
|
gather_output=False,
|
|
|
|
skip_bias_add=skip_bias_add,
|
|
|
|
params_dtype=params_dtype,
|
2024-07-18 22:39:18 -04:00
|
|
|
quant_config=quant_config,
|
|
|
|
prefix=prefix)
|
2023-11-15 22:50:41 -08:00
|
|
|
|
|
|
|
def weight_loader(self,
                  param: Parameter,
                  loaded_weight: torch.Tensor,
                  loaded_shard_id: Optional[str] = None):
    """Load a checkpoint tensor into this layer's fused QKV parameter.

    Args:
        param: Destination parameter owned by this layer (weight, bias,
            scale, or quantization metadata). Optional attributes such as
            ``output_dim``, ``packed_dim``/``pack_factor``, ``is_metadata``,
            ``needs_scalar_to_array``, ``use_bitsandbytes`` and the GGUF
            flags select the loading strategy.
        loaded_weight: Tensor read from the checkpoint.
        loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` when the checkpoint
            stores the projections separately. ``None`` when the checkpoint
            stores them already fused; in that case the tensor is split
            and this method recurses once per projection.

    Raises:
        ValueError: If ``loaded_shard_id`` is given but is not ``"q"``,
            ``"k"`` or ``"v"``. (GGUF weight-type parameters index a lookup
            table before this check and raise ``KeyError`` instead.)
    """
    # Special case for GGUF: a weight-type parameter only records the
    # quantization type of each shard, then returns.
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type and loaded_shard_id is not None:
        idx_map = {"q": 0, "k": 1, "v": 2}
        param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
        param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        return

    # Lazily allocate a GGUF weight. Its packed row width depends on the
    # quant types recorded in ``self.qweight_type``, so the widest row
    # among them is used.
    if is_gguf_weight and isinstance(param, UninitializedParameter):
        from gguf.constants import GGML_QUANT_SIZES

        ori_shape = param.tensor_shape
        weight_types = self.qweight_type.shard_weight_type.values()
        row_size = []
        for weight_type in weight_types:
            block_size, type_size = GGML_QUANT_SIZES[weight_type]
            row_size.append(ori_shape[1] // block_size * type_size)
        q_shape = (ori_shape[0], max(row_size))
        param.materialize(q_shape, dtype=loaded_weight.dtype)

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Packed quantized params: offsets/sizes along ``packed_dim`` are
    # divided by ``param.pack_factor`` below.
    packed_dim = getattr(param, "packed_dim", None)
    # Special case for AQLM codebooks.
    is_metadata = getattr(param, "is_metadata", False)
    # Special case for per-tensor scales in fused case.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv/mlp).
        if output_dim is None:
            # Nothing to split along: copy the whole tensor (after the
            # scale-array adjustment, when requested).
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0)

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return
        # Extents of q, k and v inside the full (unsharded) fused tensor,
        # computed from the total head counts.
        shard_offsets = [
            # (shard_id, shard_offset, shard_size)
            ("q", 0, self.total_num_heads * self.head_size),
            ("k", self.total_num_heads * self.head_size,
             self.total_num_kv_heads * self.head_size),
            ("v", (self.total_num_heads + self.total_num_kv_heads) *
             self.head_size, self.total_num_kv_heads * self.head_size),
        ]
        for shard_id, shard_offset, shard_size in shard_offsets:
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            # Re-enter with an explicit shard id so the per-shard path
            # below picks this rank's slice.
            loaded_weight_shard = loaded_weight.narrow(
                output_dim, shard_offset, shard_size)
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    # Validate explicitly rather than with ``assert``: asserts are removed
    # under ``python -O``, and an unknown id would otherwise surface later
    # as an UnboundLocalError from the offset selection below.
    if loaded_shard_id not in ("q", "k", "v"):
        raise ValueError(
            f"Unknown shard id {loaded_shard_id!r} for QKVParallelLinear; "
            "expected one of 'q', 'k', 'v'.")
    tp_rank = get_tensor_model_parallel_rank()

    # If output dim is defined, use the default loading process.
    if output_dim is not None:
        # Destination extent inside this rank's local fused parameter,
        # computed from the per-rank head counts.
        if loaded_shard_id == "q":
            shard_offset = 0
            shard_size = self.num_heads * self.head_size
        elif loaded_shard_id == "k":
            shard_offset = self.num_heads * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        elif loaded_shard_id == "v":
            shard_offset = (self.num_heads +
                            self.num_kv_heads) * self.head_size
            shard_size = self.num_kv_heads * self.head_size
        # Special case for Quantized Weights.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        if packed_dim == output_dim:
            shard_size = shard_size // param.pack_factor
            shard_offset = shard_offset // param.pack_factor

            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset)

        use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
        if use_bitsandbytes:
            # Unquantized (offset, size) per shard; "total" carries the
            # full unquantized extent so it can be rescaled.
            orig_qkv_offsets = {
                "q": (0, self.num_heads * self.head_size),
                "k": (self.num_heads * self.head_size,
                      self.num_kv_heads * self.head_size),
                "v":
                ((self.num_heads + self.num_kv_heads) * self.head_size,
                 self.num_kv_heads * self.head_size),
                "total":
                ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                 0)
            }
            shard_size, shard_offset = adjust_bitsandbytes_shard(
                param, orig_qkv_offsets, loaded_shard_id)

        if is_gguf_weight:
            # Record per-shard bookkeeping and restrict the destination to
            # the loaded tensor's extent along the input dim.
            param.shard_id.append(loaded_shard_id)
            param.shard_size[loaded_shard_id] = loaded_weight.shape
            input_dim = getattr(param, "input_dim", None)
            input_size = loaded_weight.shape[input_dim]
            param_data = param_data.narrow(input_dim, 0, input_size)

        param_data = param_data.narrow(output_dim, shard_offset,
                                       shard_size)
        # KV heads may be replicated across ranks: consecutive groups of
        # ``num_kv_head_replicas`` ranks read the same k/v slice.
        if loaded_shard_id == "q":
            shard_id = tp_rank
        else:
            shard_id = tp_rank // self.num_kv_head_replicas
        start_idx = shard_id * shard_size
        loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                             shard_size)
    # Special case for AQLM codebooks.
    elif is_metadata:
        # metadata indicates fixed size concatenated along dim 0
        shard_size = loaded_weight.shape[0]
        shard_index = ["q", "k", "v"].index(loaded_shard_id)
        param_data = param_data.narrow(0, shard_index * shard_size,
                                       shard_size)
    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id)
    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")

    assert param_data.shape == loaded_weight.shape
    param_data.copy_(loaded_weight)
|
|
|
|
|
|
|
|
|
2024-04-26 13:41:14 -07:00
|
|
|
class RowParallelLinear(LinearBase):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other element-wise operations.
                       We skip adding bias but instead return it.
        params_dtype: Data type for the parameters.
        reduce_results: If true, ``forward`` all-reduces the partial outputs
                        across the tensor-parallel group (when its size is
                        greater than 1); if false, each rank returns its own
                        partial result.
        quant_config: Quantization config.
        prefix: Name prefix for this layer, forwarded to the base class and
                to ``quant_method.create_weights``.
    """

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 bias: bool = True,
                 input_is_parallel: bool = True,
                 skip_bias_add: bool = False,
                 params_dtype: Optional[torch.dtype] = None,
                 reduce_results: bool = True,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config, prefix)

        self.input_is_parallel = input_is_parallel
        self.reduce_results = reduce_results

        # Without the all-reduce, the bias (added on rank 0 only in
        # ``forward``) would not be combined correctly. Reject this before
        # any weights are allocated.
        if not reduce_results and (bias and not skip_bias_add):
            raise ValueError("When not reduce the results, adding bias to the "
                             "results can lead to incorrect results")

        # Split the input features (first dim of A) across the
        # tensor-parallel group.
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, self.tp_size)
        assert self.quant_method is not None
        self.quant_method.create_weights(
            layer=self,
            input_size_per_partition=self.input_size_per_partition,
            output_partition_sizes=[self.output_size],
            input_size=self.input_size,
            output_size=self.output_size,
            params_dtype=self.params_dtype,
            weight_loader=self.weight_loader,
            prefix=prefix)

        if bias:
            # The bias spans the full output and is not sharded.
            self.bias = Parameter(
                torch.empty(self.output_size, dtype=params_dtype))
            set_weight_attrs(self.bias, {
                "output_dim": 0,
                "weight_loader": self.weight_loader,
            })
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
        """Copy this rank's slice of ``loaded_weight`` into ``param``.

        If ``param`` has an ``input_dim`` attribute, the slice is taken along
        that dim at offset ``self.tp_rank * param_data.shape[input_dim]``.
        Otherwise the tensor is copied whole. 0-dim tensors (e.g. scales
        stored without a shape) are reshaped to ``(1,)`` first.
        """
        input_dim = getattr(param, "input_dim", None)

        # Special case for GGUF
        is_gguf_weight = getattr(param, "is_gguf_weight", False)
        is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
        if is_gguf_weight_type:
            param.weight_type = loaded_weight.item()

        # Materialize GGUF UninitializedParameter.
        # NOTE(review): this uses the full checkpoint shape, so the narrow
        # below only selects an in-range slice for tp_rank == 0 — confirm
        # GGUF is only used with TP=1.
        if is_gguf_weight and isinstance(param, UninitializedParameter):
            param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

        param_data = param.data
        if input_dim is not None:
            shard_size = param_data.shape[input_dim]
            start_idx = self.tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
                                                 shard_size)

        # Special case for loading scales off disk, which often do not
        # have a shape (such as in the case of AutoFP8).
        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def forward(self, input_):
        """Apply the row-parallel linear layer.

        If ``input_is_parallel`` is false, the last dim of ``input_`` is
        split into ``tp_size`` parts and this rank keeps part ``tp_rank``.

        Returns:
            ``(output, output_bias)``. ``output_bias`` is ``self.bias`` when
            ``skip_bias_add`` is set, otherwise ``None``.
        """
        if self.input_is_parallel:
            input_parallel = input_
        else:
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        return output, output_bias

    def extra_repr(self) -> str:
        """Describe the per-rank layer configuration for ``repr``."""
        s = f"input_features={self.input_size_per_partition}"
        s += f", output_features={self.output_size}"
        s += f", bias={self.bias is not None}"
        s += f", tp_size={self.tp_size}"
        s += f", reduce_results={self.reduce_results}"
        return s
|