[LoRA] Remove linear hack outside transformers backend (#14177)

Signed-off-by: Isotr0py <2037008807@qq.com>
Authored by Isotr0py on 2025-03-05 23:06:28 +08:00, committed by GitHub
parent 257e200a25
commit e17e4488bd
4 changed files with 142 additions and 105 deletions
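In short: the transformers backend previously relied on an HFCompatibleLinear wrapper class (plus a forward monkey-patch on the LoRA side) to make vLLM's parallel linear layers behave like torch.nn.Linear. This commit instead threads a keyword-only return_bias flag through the linear layers, so the backend can ask for a plain tensor directly. A hedged usage sketch of the new flag, assuming vLLM is installed and that ReplicatedLinear can be constructed outside a running engine (it has no tensor-parallel requirements):

    import torch
    from vllm.model_executor.layers.linear import ReplicatedLinear

    x = torch.randn(2, 16)

    tuple_layer = ReplicatedLinear(16, 32)                      # default: return_bias=True
    out, out_bias = tuple_layer(x)                              # (output, output_bias), as before

    plain_layer = ReplicatedLinear(16, 32, return_bias=False)   # new keyword-only flag
    out = plain_layer(x)                                        # bare tensor, like torch.nn.Linear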

File 1 of 4: LoRA linear-layer wrappers (BaseLinearLayerWithLoRA and its subclasses)

@@ -395,17 +395,20 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        # In transformers backend, x and output have extra batch dimension like
+        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
+        # therefore we need to flatten the batch dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
         self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                             self.lora_b_stacked,
                                             self.lora_bias_stacked, 1.0,
                                             self.output_slices)
         return output
 
-    @classmethod
-    def get_source_layer(cls, source_layer: nn.Module) -> type:
-        # Check parent_cls in case source_layer is a HFCompatibleLinear.
-        return getattr(source_layer, "parent_cls", type(source_layer))
-
 
 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
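The new block above exists because the transformers backend passes activations with a leading batch dimension of 1, i.e. (1, seq_len, hidden_dim), while the punica LoRA kernels expect 2-D (num_tokens, hidden_dim) tensors. A standalone check of the reshaping this relies on (plain PyTorch, not vLLM code):

    import torch

    x = torch.randn(1, 8, 16)      # (1, seq_len, hidden_dim), as produced by the transformers backend
    x2d = x.flatten(0, 1)          # (seq_len, hidden_dim), the layout the punica kernels expect
    assert x2d.shape == (8, 16)
    # flatten on a contiguous tensor returns a view, so no copy is made and in-place
    # LoRA additions through the flattened tensor land in the same storage.
    assert x2d.data_ptr() == x.data_ptr()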
@@ -418,7 +421,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ReplicatedLinearWithLoRA
 
         Args:
@@ -436,6 +439,10 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias
 
     # ReplicatedLinear should always be replaced, regardless of the fully
@@ -448,8 +455,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ReplicatedLinear
+        return type(source_layer) is ReplicatedLinear
 
 
 class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
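With HFCompatibleLinear gone, the layers seen by can_replace_layer are plain vLLM classes again, so the getattr(source_layer, "parent_cls", ...) indirection can be dropped and an exact type(...) is ... comparison suffices. The exact check (rather than isinstance) matters because the linear classes form an inheritance chain and each LoRA wrapper spells out precisely which classes it accepts. A tiny illustration with stand-in classes (not the vLLM ones):

    class Column:                  # stands in for ColumnParallelLinear
        pass

    class MergedColumn(Column):    # stands in for MergedColumnParallelLinear
        pass

    layer = MergedColumn()
    assert isinstance(layer, Column)      # isinstance would also match the parent class
    assert type(layer) is not Column      # the exact type check keeps dispatch unambiguous
    assert type(layer) is MergedColumn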
@@ -512,7 +518,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ColumnParallelLinear
 
         Args:
@@ -532,6 +538,10 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
+
+        if not self.base_layer.return_bias:
+            return output
+
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
@@ -545,9 +555,8 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ColumnParallelLinear or (
-            source_layer is MergedColumnParallelLinear
+        return type(source_layer) is ColumnParallelLinear or (
+            type(source_layer) is MergedColumnParallelLinear
             and len(packed_modules_list) == 1)
@@ -689,8 +698,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is MergedColumnParallelLinear
+        return (type(source_layer) is MergedColumnParallelLinear
                 and len(packed_modules_list) == 2)
@@ -758,8 +766,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is QKVParallelLinear and len(
+        return type(source_layer) is QKVParallelLinear and len(
             packed_modules_list) == 1
@@ -820,8 +827,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is QKVParallelLinear
+        return (type(source_layer) is QKVParallelLinear
                 and len(packed_modules_list) == 3)
@@ -855,7 +861,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of RowParallelLinear
 
         Args:
@@ -890,6 +896,10 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         else:
             output = output_
             output_bias = self.base_layer.bias
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias
 
     @property
@@ -906,8 +916,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is RowParallelLinear
+        return type(source_layer) is RowParallelLinear
 
 
 class LogitsProcessorWithLoRA(BaseLayerWithLoRA):

File 2 of 4: LoRA layer factory (from_layer)

@@ -67,16 +67,6 @@ def from_layer(layer: nn.Module,
                                       packed_modules_list=packed_modules_list,
                                       model_config=model_config):
             instance_layer = lora_cls(layer)
-            if layer.__class__.__name__ == "HFCompatibleLinear":
-                # HACK: Make the forward method compatible with the original
-                # forward method of the instance_layer.
-                original_forward = instance_layer.forward
-
-                def new_forward(input):
-                    input = input.squeeze(0)
-                    return original_forward(input)[0]  # noqa: B023
-
-                instance_layer.forward = new_forward
             instance_layer.create_lora_weights(max_loras, lora_config,
                                                model_config)
             return instance_layer
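The deleted block was the LoRA-side half of the hack: whenever the wrapped layer was an HFCompatibleLinear, it monkey-patched the freshly built LoRA layer's forward to squeeze the leading batch dimension and index [0] to discard output_bias. Roughly, the removed patch behaved like the sketch below (a paraphrase for reference, not vLLM code); with return_bias=False on the underlying layer and the ndim == 3 flattening now done inside BaseLinearLayerWithLoRA.apply, neither step is needed:

    def _patched_forward(original_forward):
        """Paraphrase of the removed monkey-patch."""
        def new_forward(input):
            input = input.squeeze(0)           # drop the leading batch dim added by the HF-style caller
            return original_forward(input)[0]  # keep only the output tensor, discard output_bias
        return new_forward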

File 3 of 4: model executor linear layers (LinearBase and its subclasses)

@@ -2,7 +2,7 @@
 import itertools
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -152,6 +152,7 @@ class LinearBase(torch.nn.Module):
         skip_bias_add: If true, skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        return_bias: If true, return bias together with outputs in forward pass.
     """
 
     def __init__(
@@ -162,6 +163,8 @@ class LinearBase(torch.nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        *,
+        return_bias: bool = True,
     ):
         super().__init__()
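Because return_bias is introduced after a bare *, it is keyword-only on LinearBase and on every subclass updated below, so call sites have to name it explicitly. A quick reminder of what the * enforces (illustrative function, not the vLLM signature):

    def make_layer(input_size: int, output_size: int, *, return_bias: bool = True):
        return input_size, output_size, return_bias

    make_layer(16, 32, return_bias=False)   # OK: keyword-only argument is named
    # make_layer(16, 32, False)             # TypeError: takes 2 positional arguments but 3 were given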
@@ -178,9 +181,11 @@ class LinearBase(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self,
                                                               prefix=prefix)
+        self.return_bias = return_bias
 
-    def forward(self,
-                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         raise NotImplementedError
@@ -198,20 +203,25 @@ class ReplicatedLinear(LinearBase):
             (e.g. model.layers.0.qkv_proj)
     """
 
-    def __init__(self,
-                 input_size: int,
-                 output_size: int,
-                 bias: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         super().__init__(input_size,
                          output_size,
                          skip_bias_add,
                          params_dtype,
                          quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)
 
         # All the linear layer supports quant method.
         assert self.quant_method is not None
@@ -254,12 +264,15 @@ class ReplicatedLinear(LinearBase):
                              f"to a parameter of size {param.size()}")
         param.data.copy_(loaded_weight)
 
-    def forward(self,
-                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, x: torch.Tensor
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
         return output, output_bias
 
     def extra_repr(self) -> str:
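The new flag composes with the existing skip_bias_add option. Reading the updated ReplicatedLinear.forward (unquantized path), the four combinations behave as summarized below; the same pattern repeats in the column- and row-parallel classes:

    # skip_bias_add=False, return_bias=True   ->  (x @ W.T + b, None)
    # skip_bias_add=False, return_bias=False  ->   x @ W.T + b
    # skip_bias_add=True,  return_bias=True   ->  (x @ W.T, b)       # bias returned, not added
    # skip_bias_add=True,  return_bias=False  ->   x @ W.T           # bias neither added nor returned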
@@ -293,16 +306,20 @@ class ColumnParallelLinear(LinearBase):
             (e.g. model.layers.0.qkv_proj)
     """
 
-    def __init__(self,
-                 input_size: int,
-                 output_size: int,
-                 bias: bool = True,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[list[int]] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        output_sizes: Optional[list[int]] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         # Divide the weight matrix along the last dimension.
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = input_size
@@ -315,8 +332,13 @@ class ColumnParallelLinear(LinearBase):
             for output_size in self.output_sizes
         ]
 
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)
 
         self.gather_output = gather_output
@@ -393,7 +415,9 @@ class ColumnParallelLinear(LinearBase):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
@@ -405,6 +429,8 @@ class ColumnParallelLinear(LinearBase):
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
         return output, output_bias
 
     def extra_repr(self) -> str:
@@ -439,15 +465,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             (e.g. model.layers.0.qkv_proj)
     """
 
-    def __init__(self,
-                 input_size: int,
-                 output_sizes: list[int],
-                 bias: bool = True,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -458,7 +488,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)
 
     def weight_loader(self,
                       param: Parameter,
@@ -711,16 +742,20 @@ class QKVParallelLinear(ColumnParallelLinear):
             (e.g. model.layers.0.qkv_proj)
     """
 
-    def __init__(self,
-                 hidden_size: int,
-                 head_size: int,
-                 total_num_heads: int,
-                 total_num_kv_heads: Optional[int] = None,
-                 bias: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -753,7 +788,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
         shard_offset_mapping = {
@@ -1048,16 +1084,20 @@ class RowParallelLinear(LinearBase):
         quant_config: Quantization configure.
     """
 
-    def __init__(self,
-                 input_size: int,
-                 output_size: int,
-                 bias: bool = True,
-                 input_is_parallel: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 reduce_results: bool = True,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        bias: bool = True,
+        input_is_parallel: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        reduce_results: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         # Divide the weight matrix along the first dimension.
         self.tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -1065,8 +1105,13 @@ class RowParallelLinear(LinearBase):
         self.output_size_per_partition = output_size
         self.output_partition_sizes = [output_size]
 
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
@@ -1145,7 +1190,9 @@ class RowParallelLinear(LinearBase):
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
@@ -1169,6 +1216,8 @@ class RowParallelLinear(LinearBase):
 
         output_bias = self.bias if self.skip_bias_add else None
 
+        if not self.return_bias:
+            return output
         return output, output_bias
 
     def extra_repr(self) -> str:

File 4 of 4: transformers backend (replace_linear_class)

@@ -96,23 +96,12 @@ def replace_linear_class(
         "rowwise": RowParallelLinear,
     }.get(style, ReplicatedLinear)
 
-    class HFCompatibleLinear(vllm_linear_cls):
-        """
-        Wrapper class that removes `output_bias` from returned output.
-        """
-
-        # NOTE: The LoRA layer needs to use `parent_cls`.
-        @property
-        def parent_cls(self) -> type:
-            return vllm_linear_cls
-
-        def forward(self, input: torch.Tensor) -> torch.Tensor:
-            return super().forward(input)[0]
-
-    return HFCompatibleLinear(
+    return vllm_linear_cls(
         input_size=linear.in_features,
         output_size=linear.out_features,
         bias=linear.bias is not None,
         quant_config=quant_config,
+        return_bias=False,
     )
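This is the other half of the hack removal: instead of subclassing the chosen parallel class just to strip output_bias (and exposing parent_cls for the LoRA code), replace_linear_class now instantiates the vLLM class directly with return_bias=False. The tensor-only return matters because HF-style modules use the projection result immediately, e.g. self.q_proj(hidden_states).view(...). A small standalone demonstration of why a tuple return would break them (stand-in class, not vLLM code):

    import torch

    class TupleReturningLinear(torch.nn.Linear):
        """Mimics the default (output, output_bias) return convention."""
        def forward(self, x):
            return super().forward(x), None

    layer = TupleReturningLinear(16, 32)
    try:
        layer(torch.randn(1, 4, 16)).view(1, 4, 32)   # typical HF call pattern
    except AttributeError as err:
        print(f"HF-style call fails on a tuple return: {err}")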