[LoRA] Remove linear hack outside transformers backend (#14177)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent 257e200a25
commit e17e4488bd
@@ -395,17 +395,20 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

+        # In transformers backend, x and output have extra batch dimension like
+        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
+        # therefore we need to flatten the batch dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
         self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                             self.lora_b_stacked,
                                             self.lora_bias_stacked, 1.0,
                                             self.output_slices)
         return output

-    @classmethod
-    def get_source_layer(cls, source_layer: nn.Module) -> type:
-        # Check parent_cls in case source_layer is a HFCompatibleLinear.
-        return getattr(source_layer, "parent_cls", type(source_layer))
-

 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

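
Note: the added flatten only collapses the leading batch dimension that the transformers backend produces, so punica still sees the usual 2-D activations. A minimal shape check of the same pattern (standalone PyTorch, made-up sizes):

    import torch

    x = torch.randn(1, 7, 16)       # (1, seq_len, hidden_dim) from the transformers backend
    output = torch.randn(1, 7, 32)  # (1, seq_len, out_dim)

    if x.ndim == 3 and output.ndim == 3:
        output = output.flatten(0, 1)  # -> (7, 32), the layout punica expects
        x = x.flatten(0, 1)            # -> (7, 16)

    print(x.shape, output.shape)  # torch.Size([7, 16]) torch.Size([7, 32])
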
@@ -418,7 +421,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ReplicatedLinearWithLoRA

         Args:
@@ -436,6 +439,10 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias

     # ReplicatedLinear should always be replaced, regardless of the fully
@@ -448,8 +455,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ReplicatedLinear
+        return type(source_layer) is ReplicatedLinear


 class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
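
Note: with HFCompatibleLinear gone there is no parent_cls left to unwrap, so can_replace_layer returns to an exact type(...) is ... check on the source layer. A plain-Python sketch (toy stand-in classes, not the vLLM ones) of why the exact check is used instead of isinstance:

    class ColumnParallelStandIn:                        # stand-in for ColumnParallelLinear
        pass

    class MergedColumnParallelStandIn(ColumnParallelStandIn):  # stand-in for the merged variant
        pass

    layer = MergedColumnParallelStandIn()
    print(isinstance(layer, ColumnParallelStandIn))     # True  -- would also match the base class
    print(type(layer) is ColumnParallelStandIn)         # False -- exact type only
    print(type(layer) is MergedColumnParallelStandIn)   # True

Together with the len(packed_modules_list) conditions in the hunks below, this keeps each LoRA wrapper bound to exactly one linear flavor.
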
@@ -512,7 +518,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ColumnParallelLinear

         Args:
@@ -532,6 +538,10 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
+
+        if not self.base_layer.return_bias:
+            return output
+
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
@@ -545,9 +555,8 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ColumnParallelLinear or (
-            source_layer is MergedColumnParallelLinear
+        return type(source_layer) is ColumnParallelLinear or (
+            type(source_layer) is MergedColumnParallelLinear
             and len(packed_modules_list) == 1)


@@ -689,8 +698,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is MergedColumnParallelLinear
+        return (type(source_layer) is MergedColumnParallelLinear
                 and len(packed_modules_list) == 2)


@@ -758,8 +766,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is QKVParallelLinear and len(
+        return type(source_layer) is QKVParallelLinear and len(
             packed_modules_list) == 1


@@ -820,8 +827,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is QKVParallelLinear
+        return (type(source_layer) is QKVParallelLinear
                 and len(packed_modules_list) == 3)


@@ -855,7 +861,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of RowParallelLinear

         Args:
@@ -890,6 +896,10 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         else:
             output = output_
             output_bias = self.base_layer.bias
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias

     @property
@@ -906,8 +916,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is RowParallelLinear
+        return type(source_layer) is RowParallelLinear


 class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
@@ -67,16 +67,6 @@ def from_layer(layer: nn.Module,
                                       packed_modules_list=packed_modules_list,
                                       model_config=model_config):
             instance_layer = lora_cls(layer)
-            if layer.__class__.__name__ == "HFCompatibleLinear":
-                # HACK: Make the forward method compatible with the original
-                # forward method of the instance_layer.
-                original_forward = instance_layer.forward
-
-                def new_forward(input):
-                    input = input.squeeze(0)
-                    return original_forward(input)[0]  # noqa: B023
-
-                instance_layer.forward = new_forward
             instance_layer.create_lora_weights(max_loras, lora_config,
                                                model_config)
             return instance_layer
@@ -2,7 +2,7 @@

 import itertools
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Union

 import torch
 import torch.nn.functional as F
@@ -152,6 +152,7 @@ class LinearBase(torch.nn.Module):
         skip_bias_add: If true, skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        return_bias: If true, return bias together with outputs in forward pass.
     """

     def __init__(
@@ -162,6 +163,8 @@ class LinearBase(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        *,
+        return_bias: bool = True,
     ):
         super().__init__()

@@ -178,9 +181,11 @@ class LinearBase(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
+        self.return_bias = return_bias

-    def forward(self,
-                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         raise NotImplementedError

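
Note: return_bias is keyword-only and defaults to True, so the existing (output, output_bias) contract is unchanged for current callers; only code that opts in with return_bias=False receives a bare tensor, which is what the Union return annotation reflects. A minimal standalone sketch of the pattern (a toy module assuming only PyTorch, not the vLLM implementation):

    from typing import Optional, Union

    import torch
    import torch.nn as nn


    class ToyLinear(nn.Module):
        """Mimics the LinearBase contract: tuple by default, bare tensor on request."""

        def __init__(self, in_features: int, out_features: int, *,
                     return_bias: bool = True):
            super().__init__()
            self.inner = nn.Linear(in_features, out_features)
            self.return_bias = return_bias

        def forward(
            self, x: torch.Tensor
        ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
            output = self.inner(x)
            output_bias = None  # vLLM would return the deferred bias here when skip_bias_add is set
            if not self.return_bias:
                return output
            return output, output_bias


    x = torch.randn(4, 8)
    out, bias = ToyLinear(8, 16)(x)                    # default: (output, output_bias) tuple
    out_only = ToyLinear(8, 16, return_bias=False)(x)  # bare tensor
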
@@ -198,20 +203,25 @@ class ReplicatedLinear(LinearBase):
                       (e.g. model.layers.0.qkv_proj)
     """

-    def __init__(self,
-                 input_size: int,
-                 output_size: int,
-                 bias: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         super().__init__(input_size,
                          output_size,
                          skip_bias_add,
                          params_dtype,
                          quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)

         # All the linear layer supports quant method.
         assert self.quant_method is not None
@@ -254,12 +264,15 @@ class ReplicatedLinear(LinearBase):
                 f"to a parameter of size {param.size()}")
         param.data.copy_(loaded_weight)

-    def forward(self,
-                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
         return output, output_bias

     def extra_repr(self) -> str:
@@ -293,16 +306,20 @@ class ColumnParallelLinear(LinearBase):
                       (e.g. model.layers.0.qkv_proj)
     """

-    def __init__(self,
-                 input_size: int,
-                 output_size: int,
-                 bias: bool = True,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[list[int]] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        output_sizes: Optional[list[int]] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = input_size
@@ -315,8 +332,13 @@ class ColumnParallelLinear(LinearBase):
             for output_size in self.output_sizes
         ]

-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)

         self.gather_output = gather_output

@@ -393,7 +415,9 @@ class ColumnParallelLinear(LinearBase):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)

-    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None

         # Matrix multiply.
@@ -405,6 +429,8 @@ class ColumnParallelLinear(LinearBase):
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
         return output, output_bias

     def extra_repr(self) -> str:
@@ -439,15 +465,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                       (e.g. model.layers.0.qkv_proj)
     """

-    def __init__(self,
-                 input_size: int,
-                 output_sizes: list[int],
-                 bias: bool = True,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -458,7 +488,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)

     def weight_loader(self,
                       param: Parameter,
@@ -711,16 +742,20 @@ class QKVParallelLinear(ColumnParallelLinear):
                       (e.g. model.layers.0.qkv_proj)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 head_size: int,
-                 total_num_heads: int,
-                 total_num_kv_heads: Optional[int] = None,
-                 bias: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -753,7 +788,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)

     def _get_shard_offset_mapping(self, loaded_shard_id: str):
         shard_offset_mapping = {
@@ -1048,16 +1084,20 @@ class RowParallelLinear(LinearBase):
         quant_config: Quantization configure.
     """

-    def __init__(self,
-                 input_size: int,
-                 output_size: int,
-                 bias: bool = True,
-                 input_is_parallel: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 reduce_results: bool = True,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         # Divide the weight matrix along the first dimension.
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -1065,8 +1105,13 @@ class RowParallelLinear(LinearBase):
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]

-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)

         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
@@ -1145,7 +1190,9 @@ class RowParallelLinear(LinearBase):

         param.load_row_parallel_weight(loaded_weight=loaded_weight)

-    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1169,6 +1216,8 @@ class RowParallelLinear(LinearBase):

         output_bias = self.bias if self.skip_bias_add else None

+        if not self.return_bias:
+            return output
         return output, output_bias

     def extra_repr(self) -> str:
@@ -96,23 +96,12 @@ def replace_linear_class(
         "rowwise": RowParallelLinear,
     }.get(style, ReplicatedLinear)

-    class HFCompatibleLinear(vllm_linear_cls):
-        """
-        Wrapper class that removes `output_bias` from returned output.
-        """
-        # NOTE: The LoRA layer needs to use `parent_cls`.
-        @property
-        def parent_cls(self) -> type:
-            return vllm_linear_cls
-
-        def forward(self, input: torch.Tensor) -> torch.Tensor:
-            return super().forward(input)[0]
-
-    return HFCompatibleLinear(
+    return vllm_linear_cls(
         input_size=linear.in_features,
         output_size=linear.out_features,
         bias=linear.bias is not None,
         quant_config=quant_config,
+        return_bias=False,
     )

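
Note: on the caller side this is what lets the transformers backend drop the wrapper class entirely; replace_linear_class now instantiates the mapped vLLM linear class directly with return_bias=False, and HF-defined modules keep receiving a plain tensor without the old forward(...)[0] indexing. A rough self-contained sketch of that calling convention (toy classes assuming only PyTorch; hypothetical names, not the real backend code):

    import torch
    import torch.nn as nn


    class TupleLinear(nn.Module):
        """Stand-in for a vLLM linear: (output, bias) tuple unless return_bias=False."""

        def __init__(self, in_features: int, out_features: int, *,
                     return_bias: bool = True):
            super().__init__()
            self.inner = nn.Linear(in_features, out_features)
            self.return_bias = return_bias

        def forward(self, x: torch.Tensor):
            output = self.inner(x)
            return output if not self.return_bias else (output, None)


    def replace_linear(linear: nn.Linear) -> nn.Module:
        # Mirrors the replace_linear_class idea: construct the replacement with
        # return_bias=False so HF-style callers get a tensor, no wrapper subclass needed.
        return TupleLinear(linear.in_features, linear.out_features,
                           return_bias=False)


    hf_proj = replace_linear(nn.Linear(8, 16))
    hidden = hf_proj(torch.randn(2, 8))  # plain tensor of shape (2, 16)
    print(hidden.shape)
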