Remove ScaledActivation for AWQ (#10057)

Signed-off-by: mgoin <michael@neuralmagic.com>
Author: Michael Goin (committed by GitHub)
Date:   2024-11-06 09:27:06 -05:00
Commit: 399c798608
Parent: 406d4cc480
34 changed files with 19 additions and 124 deletions

View File

@@ -9,7 +9,6 @@ import torch.nn.functional as F
 
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import LazyDict
@@ -277,28 +276,14 @@ _ACTIVATION_REGISTRY = LazyDict({
 })
 
 
-def get_act_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_fn(act_fn_name: str) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")
-    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_REGISTRY[act_fn_name]
 
 
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
@@ -307,25 +292,11 @@ _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
 })
 
 
-def get_act_and_mul_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
     """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")
-    act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
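
With the `ScaledActivation` indirection removed, both helpers are plain registry lookups. A minimal before/after sketch of a call site, assuming the usual `vllm.model_executor.layers.activation` module path (the old signature is reproduced from the diff above):

```python
from vllm.model_executor.layers.activation import (get_act_fn,
                                                   get_act_and_mul_fn)

# Old signature, removed above:
#     get_act_fn("gelu", quant_config, intermediate_size,
#                input_is_parallel, params_dtype)
# New signature: just the activation name.
act = get_act_fn("gelu")               # plain GELU nn.Module
act_mul = get_act_and_mul_fn("silu")   # fused SiluAndMul variant
```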

View File

@@ -213,9 +213,6 @@ class AQLMConfig(QuantizationConfig):
             return AQLMLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class AQLMLinearMethod(LinearMethodBase):
     """Linear method for AQLM.

View File

@@ -77,9 +77,6 @@ class AWQConfig(QuantizationConfig):
             return AWQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-
 
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
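
`AWQConfig` was the only in-tree config that returned a non-empty list here, which is what routed GELU variants through `ScaledActivation`. For reference, a rough paraphrase of the removed wrapper (simplified: the real class partitioned `scales` across tensor-parallel ranks and filled it from the checkpoint during weight loading):

```python
import torch
import torch.nn as nn


class ScaledActivationSketch(nn.Module):
    """Paraphrase of the removed ScaledActivation: apply the wrapped
    activation, then divide by a per-channel scale from the checkpoint."""

    def __init__(self, act_module: nn.Module, intermediate_size: int,
                 params_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.act = act_module
        # Populated from the AWQ checkpoint at weight-loading time.
        self.scales = nn.Parameter(
            torch.empty(intermediate_size, dtype=params_dtype))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales
```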

View File

@@ -127,9 +127,6 @@ class AWQMarlinConfig(QuantizationConfig):
             return AWQMoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.

View File

@@ -133,11 +133,3 @@ class QuantizationConfig(ABC):
         method.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def get_scaled_act_names(self) -> List[str]:
-        """Returns the activation function names that should be post-scaled.
-
-        For now, this is only used by AWQ.
-        """
-        raise NotImplementedError
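
Because the method was abstract on the base class, every backend had to provide it, almost always as a `return []` stub; the deletions in the config files below are exactly those stubs. A toy reconstruction of the interface change (illustrative only; the real ABC has several more abstract methods):

```python
from abc import ABC, abstractmethod
from typing import List


class QuantConfigBefore(ABC):
    """Pre-commit shape: every subclass was forced to implement this."""

    @abstractmethod
    def get_scaled_act_names(self) -> List[str]:
        """Activation names to post-scale; in practice only AWQ used it."""
        raise NotImplementedError


class QuantConfigAfter(ABC):
    """Post-commit shape: the hook is gone, so the stubs are unneeded."""
```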

View File

@@ -114,9 +114,6 @@ class BitsAndBytesConfig(QuantizationConfig):
             return BitsAndBytesLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
     # Split the prefix into its dot-separated components

View File

@@ -45,9 +45,6 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

View File

@@ -50,9 +50,6 @@ class DeepSpeedFPConfig(QuantizationConfig):
     def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
         return DeepSpeedFPLinearMethod(self)
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half, torch.bfloat16]

View File

@@ -45,9 +45,6 @@ class ExpertsInt8Config(QuantizationConfig):
             return ExpertsInt8MoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ExpertsInt8MoEMethod(FusedMoEMethodBase):

View File

@@ -64,9 +64,6 @@ class FBGEMMFp8Config(QuantizationConfig):
             return FBGEMMFp8LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class FBGEMMFp8LinearMethod(LinearMethodBase):

View File

@@ -92,9 +92,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8KVCacheMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.

View File

@@ -48,9 +48,6 @@ class GGUFConfig(QuantizationConfig):
             return GGUFEmbeddingMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:

View File

@@ -80,9 +80,6 @@ class GPTQConfig(QuantizationConfig):
             return GPTQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ExllamaState(Enum):

View File

@@ -125,9 +125,6 @@ class GPTQMarlinConfig(QuantizationConfig):
             return GPTQMarlinMoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.

View File

@@ -127,9 +127,6 @@ class GPTQMarlin24Config(QuantizationConfig):
             return GPTQMarlin24LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class GPTQMarlin24LinearMethod(LinearMethodBase):
     """Linear method for Marlin24.

View File

@@ -93,12 +93,6 @@ class IPEXConfig(QuantizationConfig):
             return self.quant_method(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        if self.method == "awq":
-            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-        else:
-            return []
-
 
 class IPEXAWQLinearMethod(AWQLinearMethod):
     """AWQ linear method using IPEX for the CPU backend.

View File

@@ -110,9 +110,6 @@ class MarlinConfig(QuantizationConfig):
             return MarlinLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class MarlinLinearMethod(LinearMethodBase):
     """Linear method for Marlin.

View File

@@ -68,9 +68,6 @@ class ModelOptFp8Config(QuantizationConfig):
             return ModelOptFp8KVCacheMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """

View File

@@ -57,9 +57,6 @@ class NeuronQuantConfig(QuantizationConfig):
                 "Neuron Quantization is only supported through"
                 " transformers_neuronx.")
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_quantization_config(self):
         from transformers_neuronx.config import QuantizationConfig
         return QuantizationConfig(quant_dtype=self.quant_dtype,

View File

@@ -112,9 +112,6 @@ class QQQConfig(QuantizationConfig):
             return QQQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class QQQLinearMethod(LinearMethodBase):
     """Linear method for QQQ.

View File

@@ -50,9 +50,6 @@ class Int8TpuConfig(QuantizationConfig):
             return TPUInt8LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class TPUInt8LinearMethod(LinearMethodBase):
     """Int8 Linear method for TPU Quant. """

View File

@@ -393,8 +393,7 @@ class BartEncoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config)
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config)
+        self.activation_fn = get_act_fn(config.activation_function)
 
         ffn_hidden_size = self.embed_dim
         ffn_intermediate_size = config.encoder_ffn_dim
@@ -405,7 +404,7 @@ class BartEncoderLayer(nn.Module):
             bias=ffn_has_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
+        self.act = get_act_fn("gelu")
         self.fc2 = RowParallelLinear(
             ffn_intermediate_size,
             ffn_hidden_size,
@@ -473,8 +472,7 @@ class BartDecoderLayer(nn.Module):
             config=config,
             cache_config=cache_config,
             quant_config=quant_config)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config)
+        self.activation_fn = get_act_fn(config.activation_function)
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         '''
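
The remaining model files repeat the BART pattern verbatim: the activation lookup drops its `quant_config` and intermediate-size arguments. A quick sanity-check sketch of the new call, assuming a working vllm install (shapes are arbitrary):

```python
import torch

from vllm.model_executor.layers.activation import get_act_fn

# The lookup now returns a plain activation module, never a
# quantization-aware ScaledActivation wrapper.
act = get_act_fn("gelu")
x = torch.randn(2, 8)
assert act(x).shape == x.shape
```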

View File

@@ -146,7 +146,7 @@ class BloomMLP(nn.Module):
             4 * hidden_size,
             quant_config=quant_config,
         )
-        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,

View File

@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
                                               bias=config.bias,
                                               skip_bias_add=True,
                                               quant_config=quant_config)
-        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.act = get_act_fn("gelu")
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(

View File

@@ -123,8 +123,7 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

View File

@@ -135,8 +135,7 @@ class GPTBigMLP(nn.Module):
             bias=True,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

View File

@@ -130,8 +130,7 @@ class GPTJMLP(nn.Module):
             hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)

View File

@@ -128,8 +128,7 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)

View File

@@ -153,7 +153,7 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,

View File

@@ -147,8 +147,7 @@ class OPTDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config, config.ffn_dim)
+        self.activation_fn = get_act_fn(config.activation_function)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,

View File

@@ -60,7 +60,7 @@ class PersimmonMLP(nn.Module):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = get_act_fn(config.hidden_act, quant_config)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)

View File

@@ -152,7 +152,7 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)

View File

@@ -203,7 +203,7 @@ class QwenVMLP(nn.Module):
                                       intermediate_size,
                                       bias=True,
                                       quant_config=quant_config)
-        self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act_fn = get_act_fn("gelu")
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,

View File

@@ -139,8 +139,7 @@ class Starcoder2MLP(nn.Module):
             bias=config.use_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)