diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py
index 5a4d991d..ff1b6501 100644
--- a/vllm/lora/layers.py
+++ b/vllm/lora/layers.py
@@ -395,17 +395,20 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
+
+        # In transformers backend, x and output have extra batch dimension like
+        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
+        # therefore we need to flatten the batch dimensions.
+        if x.ndim == 3 and output.ndim == 3:
+            output = output.flatten(0, 1)
+            x = x.flatten(0, 1)
+
         self.punica_wrapper.add_lora_linear(output, x, self.lora_a_stacked,
                                             self.lora_b_stacked,
                                             self.lora_bias_stacked, 1.0,
                                             self.output_slices)
         return output

-    @classmethod
-    def get_source_layer(cls, source_layer: nn.Module) -> type:
-        # Check parent_cls in case source_layer is a HFCompatibleLinear.
-        return getattr(source_layer, "parent_cls", type(source_layer))
-

 class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

@@ -418,7 +421,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ReplicatedLinearWithLoRA

         Args:
@@ -436,6 +439,10 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias

     # ReplicatedLinear should always be replaced, regardless of the fully
@@ -448,8 +455,7 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ReplicatedLinear
+        return type(source_layer) is ReplicatedLinear


 class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
@@ -512,7 +518,7 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of ColumnParallelLinear

         Args:
@@ -532,6 +538,10 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
             output = tensor_model_parallel_all_gather(output_parallel)
         else:
             output = output_parallel
+
+        if not self.base_layer.return_bias:
+            return output
+
         output_bias = (self.base_layer.bias
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
@@ -545,9 +555,8 @@ class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is ColumnParallelLinear or (
-            source_layer is MergedColumnParallelLinear
+        return type(source_layer) is ColumnParallelLinear or (
+            type(source_layer) is MergedColumnParallelLinear
             and len(packed_modules_list) == 1)


@@ -689,8 +698,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is MergedColumnParallelLinear
+        return (type(source_layer) is MergedColumnParallelLinear
                 and len(packed_modules_list) == 2)


@@ -758,8 +766,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def can_replace_layer(cls, source_layer: nn.Module,
                           lora_config: LoRAConfig, packed_modules_list: List,
                           model_config: Optional[PretrainedConfig]) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is QKVParallelLinear and len(
+        return type(source_layer) is QKVParallelLinear and len(
             packed_modules_list) == 1

@@ -820,8 +827,7 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return (source_layer is QKVParallelLinear
+        return (type(source_layer) is QKVParallelLinear
                 and len(packed_modules_list) == 3)

@@ -855,7 +861,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):

     def forward(
         self, input_: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
         """Forward of RowParallelLinear

         Args:
@@ -890,6 +896,10 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         else:
             output = output_
         output_bias = self.base_layer.bias
+
+        if not self.base_layer.return_bias:
+            return output
+
         return output, output_bias

     @property
@@ -906,8 +916,7 @@ class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
         packed_modules_list: List,
         model_config: Optional[PretrainedConfig],
     ) -> bool:
-        source_layer = cls.get_source_layer(source_layer)
-        return source_layer is RowParallelLinear
+        return type(source_layer) is RowParallelLinear


 class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 9f1b14b4..610cbf87 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -67,16 +67,6 @@ def from_layer(layer: nn.Module,
                                       packed_modules_list=packed_modules_list,
                                       model_config=model_config):
             instance_layer = lora_cls(layer)
-            if layer.__class__.__name__ == "HFCompatibleLinear":
-                # HACK: Make the forward method compatible with the original
-                # forward method of the instance_layer.
-                original_forward = instance_layer.forward
-
-                def new_forward(input):
-                    input = input.squeeze(0)
-                    return original_forward(input)[0]  # noqa: B023
-
-                instance_layer.forward = new_forward
             instance_layer.create_lora_weights(max_loras, lora_config,
                                                model_config)
             return instance_layer
diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index b9c85aaf..600284a8 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -2,7 +2,7 @@
 import itertools
 from abc import abstractmethod
-from typing import Optional
+from typing import Optional, Union

 import torch
 import torch.nn.functional as F
@@ -152,6 +152,7 @@ class LinearBase(torch.nn.Module):
         skip_bias_add: If true, skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         quant_config: Quantization configure.
+        return_bias: If true, return bias together with outputs in forward pass.
""" def __init__( @@ -162,6 +163,8 @@ class LinearBase(torch.nn.Module): params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + *, + return_bias: bool = True, ): super().__init__() @@ -178,9 +181,11 @@ class LinearBase(torch.nn.Module): else: self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + self.return_bias = return_bias - def forward(self, - x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]: + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: raise NotImplementedError @@ -198,20 +203,25 @@ class ReplicatedLinear(LinearBase): (e.g. model.layers.0.qkv_proj) """ - def __init__(self, - input_size: int, - output_size: int, - bias: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config, - prefix=prefix) + prefix=prefix, + return_bias=return_bias) # All the linear layer supports quant method. assert self.quant_method is not None @@ -254,12 +264,15 @@ class ReplicatedLinear(LinearBase): f"to a parameter of size {param.size()}") param.data.copy_(loaded_weight) - def forward(self, - x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]: + def forward( + self, x: torch.Tensor + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: bias = self.bias if not self.skip_bias_add else None assert self.quant_method is not None output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output return output, output_bias def extra_repr(self) -> str: @@ -293,16 +306,20 @@ class ColumnParallelLinear(LinearBase): (e.g. model.layers.0.qkv_proj) """ - def __init__(self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - output_sizes: Optional[list[int]] = None, - prefix: str = ""): + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + output_sizes: Optional[list[int]] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): # Divide the weight matrix along the last dimension. 
         self.tp_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = input_size
@@ -315,8 +332,13 @@ class ColumnParallelLinear(LinearBase):
                 for output_size in self.output_sizes
             ]

-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
+        super().__init__(input_size,
+                         output_size,
+                         skip_bias_add,
+                         params_dtype,
+                         quant_config,
+                         prefix,
+                         return_bias=return_bias)

         self.gather_output = gather_output

@@ -393,7 +415,9 @@ class ColumnParallelLinear(LinearBase):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)

-    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
+    def forward(
+        self, input_
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
         bias = self.bias if not self.skip_bias_add else None

         # Matrix multiply.
@@ -405,6 +429,8 @@ class ColumnParallelLinear(LinearBase):
         else:
             output = output_parallel
         output_bias = self.bias if self.skip_bias_add else None
+        if not self.return_bias:
+            return output
         return output, output_bias

     def extra_repr(self) -> str:
@@ -439,15 +465,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         (e.g. model.layers.0.qkv_proj)
     """

-    def __init__(self,
-                 input_size: int,
-                 output_sizes: list[int],
-                 bias: bool = True,
-                 gather_output: bool = False,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         self.output_sizes = output_sizes
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
@@ -458,7 +488,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)

     def weight_loader(self,
                       param: Parameter,
@@ -711,16 +742,20 @@ class QKVParallelLinear(ColumnParallelLinear):
         (e.g. model.layers.0.qkv_proj)
     """

-    def __init__(self,
-                 hidden_size: int,
-                 head_size: int,
-                 total_num_heads: int,
-                 total_num_kv_heads: Optional[int] = None,
-                 bias: bool = True,
-                 skip_bias_add: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        hidden_size: int,
+        head_size: int,
+        total_num_heads: int,
+        total_num_kv_heads: Optional[int] = None,
+        bias: bool = True,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        *,
+        return_bias: bool = True,
+    ):
         self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
@@ -753,7 +788,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                          skip_bias_add=skip_bias_add,
                          params_dtype=params_dtype,
                          quant_config=quant_config,
-                         prefix=prefix)
+                         prefix=prefix,
+                         return_bias=return_bias)

     def _get_shard_offset_mapping(self, loaded_shard_id: str):
         shard_offset_mapping = {
@@ -1048,16 +1084,20 @@ class RowParallelLinear(LinearBase):
         quant_config: Quantization configure.
""" - def __init__(self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + ): # Divide the weight matrix along the first dimension. self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = get_tensor_model_parallel_world_size() @@ -1065,8 +1105,13 @@ class RowParallelLinear(LinearBase): self.output_size_per_partition = output_size self.output_partition_sizes = [output_size] - super().__init__(input_size, output_size, skip_bias_add, params_dtype, - quant_config, prefix) + super().__init__(input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias) self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results @@ -1145,7 +1190,9 @@ class RowParallelLinear(LinearBase): param.load_row_parallel_weight(loaded_weight=loaded_weight) - def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]: + def forward( + self, input_ + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: if self.input_is_parallel: input_parallel = input_ else: @@ -1169,6 +1216,8 @@ class RowParallelLinear(LinearBase): output_bias = self.bias if self.skip_bias_add else None + if not self.return_bias: + return output return output, output_bias def extra_repr(self) -> str: diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py index a6bfdebb..be788d63 100644 --- a/vllm/model_executor/models/transformers.py +++ b/vllm/model_executor/models/transformers.py @@ -96,23 +96,12 @@ def replace_linear_class( "rowwise": RowParallelLinear, }.get(style, ReplicatedLinear) - class HFCompatibleLinear(vllm_linear_cls): - """ - Wrapper class that removes `output_bias` from returned output. - """ - # NOTE: The LoRA layer needs to use `parent_cls`. - @property - def parent_cls(self) -> type: - return vllm_linear_cls - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input)[0] - - return HFCompatibleLinear( + return vllm_linear_cls( input_size=linear.in_features, output_size=linear.out_features, bias=linear.bias is not None, quant_config=quant_config, + return_bias=False, )