Remove ScaledActivation for AWQ (#10057)
Signed-off-by: mgoin <michael@neuralmagic.com>
parent 406d4cc480
commit 399c798608
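
This commit removes the ScaledActivation special case that AWQ relied on: get_act_fn and get_act_and_mul_fn shrink to plain name-to-module lookups, the abstract QuantizationConfig.get_scaled_act_names hook is deleted along with every backend's boilerplate override, and all model call sites stop passing quant_config and intermediate_size. A minimal before/after sketch of the call-site migration (all names taken from the diff below):

    # Before: quantization details were threaded through the lookup so AWQ
    # could wrap the activation in ScaledActivation.
    act = get_act_fn("gelu", quant_config, intermediate_size)

    # After: a pure name -> nn.Module lookup.
    act = get_act_fn("gelu")
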
@@ -9,7 +9,6 @@ import torch.nn.functional as F
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import LazyDict

@@ -277,28 +276,14 @@ _ACTIVATION_REGISTRY = LazyDict({
 })
 
 
-def get_act_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_fn(act_fn_name: str) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")
 
-    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_REGISTRY[act_fn_name]
 
 
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({

@@ -307,25 +292,11 @@ _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
 })
 
 
-def get_act_and_mul_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
     """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")
 
-    act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]

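For context, what the removed branch constructed: ScaledActivation applies the wrapped activation and then divides elementwise by per-channel scales loaded from an AWQ checkpoint. A simplified sketch, not the exact vLLM class (the real one also shards the scales across tensor-parallel ranks via input_is_parallel and divide, and registers a weight loader with set_weight_attrs):

    import torch
    import torch.nn as nn

    class ScaledActivationSketch(nn.Module):
        """Simplified stand-in for the removed wrapper: activation output
        divided by learned per-channel scales."""

        def __init__(self, act_module: nn.Module, intermediate_size: int,
                     params_dtype: torch.dtype):
            super().__init__()
            self.act = act_module
            self.scales = nn.Parameter(
                torch.empty(intermediate_size, dtype=params_dtype))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.act(x) / self.scales
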
@@ -213,9 +213,6 @@ class AQLMConfig(QuantizationConfig):
             return AQLMLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class AQLMLinearMethod(LinearMethodBase):
     """Linear method for AQLM.

@@ -77,9 +77,6 @@ class AWQConfig(QuantizationConfig):
             return AWQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-
 
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)

@@ -127,9 +127,6 @@ class AWQMarlinConfig(QuantizationConfig):
             return AWQMoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.

@@ -133,11 +133,3 @@ class QuantizationConfig(ABC):
         method.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def get_scaled_act_names(self) -> List[str]:
-        """Returns the activation function names that should be post-scaled.
-
-        For now, this is only used by AWQ.
-        """
-        raise NotImplementedError

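Every hunk that follows deletes the boilerplate this abstract method forced onto backends. A sketch of the removed pattern (hypothetical subclass name; per the hunks here, only AWQConfig, and IPEXConfig in AWQ mode, ever returned a non-empty list):

    from typing import List

    class SomeBackendConfig(QuantizationConfig):

        # Existed only to satisfy the abstract hook above; removed by this
        # commit together with the hook itself.
        def get_scaled_act_names(self) -> List[str]:
            return []
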
@@ -114,9 +114,6 @@ class BitsAndBytesConfig(QuantizationConfig):
             return BitsAndBytesLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
     # Split the prefix into its dot-separated components

@@ -45,9 +45,6 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

@@ -50,9 +50,6 @@ class DeepSpeedFPConfig(QuantizationConfig):
     def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
         return DeepSpeedFPLinearMethod(self)
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half, torch.bfloat16]

@@ -45,9 +45,6 @@ class ExpertsInt8Config(QuantizationConfig):
             return ExpertsInt8MoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ExpertsInt8MoEMethod(FusedMoEMethodBase):

@@ -64,9 +64,6 @@ class FBGEMMFp8Config(QuantizationConfig):
             return FBGEMMFp8LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class FBGEMMFp8LinearMethod(LinearMethodBase):

@@ -92,9 +92,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8KVCacheMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.

@@ -48,9 +48,6 @@ class GGUFConfig(QuantizationConfig):
             return GGUFEmbeddingMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:

@@ -80,9 +80,6 @@ class GPTQConfig(QuantizationConfig):
             return GPTQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ExllamaState(Enum):

@@ -125,9 +125,6 @@ class GPTQMarlinConfig(QuantizationConfig):
             return GPTQMarlinMoEMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.

@@ -127,9 +127,6 @@ class GPTQMarlin24Config(QuantizationConfig):
             return GPTQMarlin24LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class GPTQMarlin24LinearMethod(LinearMethodBase):
     """Linear method for Marlin24.

@@ -93,12 +93,6 @@ class IPEXConfig(QuantizationConfig):
             return self.quant_method(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        if self.method == "awq":
-            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-        else:
-            return []
-
 
 class IPEXAWQLinearMethod(AWQLinearMethod):
     """AWQ linear method using IPEX for the CPU backend.

@@ -110,9 +110,6 @@ class MarlinConfig(QuantizationConfig):
             return MarlinLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class MarlinLinearMethod(LinearMethodBase):
     """Linear method for Marlin.

@@ -68,9 +68,6 @@ class ModelOptFp8Config(QuantizationConfig):
             return ModelOptFp8KVCacheMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """

@@ -57,9 +57,6 @@ class NeuronQuantConfig(QuantizationConfig):
             "Neuron Quantization is only supported through"
             " transformers_neuronx.")
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_quantization_config(self):
         from transformers_neuronx.config import QuantizationConfig
         return QuantizationConfig(quant_dtype=self.quant_dtype,

@@ -112,9 +112,6 @@ class QQQConfig(QuantizationConfig):
             return QQQLinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class QQQLinearMethod(LinearMethodBase):
     """Linear method for QQQ.

@@ -50,9 +50,6 @@ class Int8TpuConfig(QuantizationConfig):
             return TPUInt8LinearMethod(self)
         return None
 
-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
 
 class TPUInt8LinearMethod(LinearMethodBase):
     """Int8 Linear method for TPU Quant. """

@@ -393,8 +393,7 @@ class BartEncoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config)
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config)
+        self.activation_fn = get_act_fn(config.activation_function)
 
         ffn_hidden_size = self.embed_dim
         ffn_intermediate_size = config.encoder_ffn_dim

@@ -405,7 +404,7 @@ class BartEncoderLayer(nn.Module):
             bias=ffn_has_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
+        self.act = get_act_fn("gelu")
         self.fc2 = RowParallelLinear(
             ffn_intermediate_size,
             ffn_hidden_size,

|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
self.activation_fn = get_act_fn(config.activation_function,
|
||||
quant_config)
|
||||
self.activation_fn = get_act_fn(config.activation_function)
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||
'''
|
||||
|
@@ -146,7 +146,7 @@ class BloomMLP(nn.Module):
             4 * hidden_size,
             quant_config=quant_config,
         )
-        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,

@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
             bias=config.bias,
             skip_bias_add=True,
             quant_config=quant_config)
-        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.act = get_act_fn("gelu")
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(

@@ -123,8 +123,7 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

@@ -135,8 +135,7 @@ class GPTBigMLP(nn.Module):
             bias=True,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

@@ -130,8 +130,7 @@ class GPTJMLP(nn.Module):
             hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)

@@ -128,8 +128,7 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)

@@ -153,7 +153,7 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,

@@ -147,8 +147,7 @@ class OPTDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config, config.ffn_dim)
+        self.activation_fn = get_act_fn(config.activation_function)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,

@@ -60,7 +60,7 @@ class PersimmonMLP(nn.Module):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = get_act_fn(config.hidden_act, quant_config)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)

@@ -152,7 +152,7 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)

@@ -203,7 +203,7 @@ class QwenVMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config)
-        self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act_fn = get_act_fn("gelu")
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,

@@ -139,8 +139,7 @@ class Starcoder2MLP(nn.Module):
             bias=config.use_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)

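After the migration, activation lookup needs no quantization context at all. A quick usage check (assuming a vLLM tree with this commit applied; the module path matches the activation.py hunks above):

    from vllm.model_executor.layers.activation import get_act_fn

    act = get_act_fn("gelu_new")  # plain nn.Module from the registry
    # Unknown names raise ValueError: "Activation function ... is not supported."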