Remove ScaledActivation for AWQ (#10057)
Signed-off-by: mgoin <michael@neuralmagic.com>
parent 406d4cc480
commit 399c798608
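
For reviewers, a minimal sketch of the call-site migration (illustrative only, assuming the post-commit signatures shown in the hunks below; not part of the diff itself):

    import torch
    from vllm.model_executor.layers.activation import get_act_fn

    # Before this commit, callers threaded quantization details through:
    #     get_act_fn("gelu", quant_config, intermediate_size)
    # and AWQ-style configs could wrap the result in ScaledActivation.
    # After this commit, only the activation name is needed:
    act = get_act_fn("gelu")
    out = act(torch.randn(2, 8))  # plain activation module, no post-scaling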
@@ -9,7 +9,6 @@ import torch.nn.functional as F
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size)
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import LazyDict

@@ -277,28 +276,14 @@ _ACTIVATION_REGISTRY = LazyDict({
 })


-def get_act_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_fn(act_fn_name: str) -> nn.Module:
     """Get an activation function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")

-    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_REGISTRY[act_fn_name]


 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
@@ -307,25 +292,11 @@ _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
 })


-def get_act_and_mul_fn(
-    act_fn_name: str,
-    quant_config: Optional[QuantizationConfig] = None,
-    intermediate_size: Optional[int] = None,
-    input_is_parallel: bool = True,
-    params_dtype: Optional[torch.dtype] = None,
-) -> nn.Module:
+def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
     """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
     act_fn_name = act_fn_name.lower()
     if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
         raise ValueError(
             f"Activation function {act_fn_name!r} is not supported.")

-    act_fn = _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
-    if (quant_config is not None
-            and act_fn_name in quant_config.get_scaled_act_names()):
-        if intermediate_size is None:
-            raise ValueError("intermediate_size must be specified for scaled "
-                             "activation functions.")
-        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
-                                params_dtype)
-    return act_fn
+    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
@@ -213,9 +213,6 @@ class AQLMConfig(QuantizationConfig):
             return AQLMLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class AQLMLinearMethod(LinearMethodBase):
     """Linear method for AQLM.
@@ -77,9 +77,6 @@ class AWQConfig(QuantizationConfig):
             return AWQLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-

 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
@@ -127,9 +127,6 @@ class AWQMarlinConfig(QuantizationConfig):
             return AWQMoEMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
@@ -133,11 +133,3 @@ class QuantizationConfig(ABC):
         method.
         """
         raise NotImplementedError
-
-    @abstractmethod
-    def get_scaled_act_names(self) -> List[str]:
-        """Returns the activation function names that should be post-scaled.
-
-        For now, this is only used by AWQ.
-        """
-        raise NotImplementedError
@@ -114,9 +114,6 @@ class BitsAndBytesConfig(QuantizationConfig):
             return BitsAndBytesLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
     # Split the prefix into its dot-separated components
@@ -45,9 +45,6 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.float16, torch.bfloat16]

@@ -50,9 +50,6 @@ class DeepSpeedFPConfig(QuantizationConfig):
     def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
         return DeepSpeedFPLinearMethod(self)

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
         return [torch.half, torch.bfloat16]
@@ -45,9 +45,6 @@ class ExpertsInt8Config(QuantizationConfig):
             return ExpertsInt8MoEMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class ExpertsInt8MoEMethod(FusedMoEMethodBase):

@@ -64,9 +64,6 @@ class FBGEMMFp8Config(QuantizationConfig):
             return FBGEMMFp8LinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class FBGEMMFp8LinearMethod(LinearMethodBase):

@@ -92,9 +92,6 @@ class Fp8Config(QuantizationConfig):
             return Fp8KVCacheMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class Fp8LinearMethod(LinearMethodBase):
     """Linear method for FP8.
@@ -48,9 +48,6 @@ class GGUFConfig(QuantizationConfig):
             return GGUFEmbeddingMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
                   qweight_type: int) -> torch.Tensor:
@@ -80,9 +80,6 @@ class GPTQConfig(QuantizationConfig):
             return GPTQLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class ExllamaState(Enum):

@@ -125,9 +125,6 @@ class GPTQMarlinConfig(QuantizationConfig):
             return GPTQMarlinMoEMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
         # Extract data from quant config.
@@ -127,9 +127,6 @@ class GPTQMarlin24Config(QuantizationConfig):
             return GPTQMarlin24LinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class GPTQMarlin24LinearMethod(LinearMethodBase):
     """Linear method for Marlin24.
@@ -93,12 +93,6 @@ class IPEXConfig(QuantizationConfig):
             return self.quant_method(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        if self.method == "awq":
-            return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
-        else:
-            return []
-

 class IPEXAWQLinearMethod(AWQLinearMethod):
     """AWQ linear method using IPEX for the CPU backend.
@@ -110,9 +110,6 @@ class MarlinConfig(QuantizationConfig):
             return MarlinLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class MarlinLinearMethod(LinearMethodBase):
     """Linear method for Marlin.
|
@ -68,9 +68,6 @@ class ModelOptFp8Config(QuantizationConfig):
|
|||||||
return ModelOptFp8KVCacheMethod(self)
|
return ModelOptFp8KVCacheMethod(self)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_scaled_act_names(self) -> List[str]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
|
||||||
"""
|
"""
|
||||||
|
@@ -57,9 +57,6 @@ class NeuronQuantConfig(QuantizationConfig):
             "Neuron Quantization is only supported through"
             " transformers_neuronx.")

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-
     def get_quantization_config(self):
         from transformers_neuronx.config import QuantizationConfig
         return QuantizationConfig(quant_dtype=self.quant_dtype,
@@ -112,9 +112,6 @@ class QQQConfig(QuantizationConfig):
             return QQQLinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class QQQLinearMethod(LinearMethodBase):
     """Linear method for QQQ.
@@ -50,9 +50,6 @@ class Int8TpuConfig(QuantizationConfig):
             return TPUInt8LinearMethod(self)
         return None

-    def get_scaled_act_names(self) -> List[str]:
-        return []
-

 class TPUInt8LinearMethod(LinearMethodBase):
     """Int8 Linear method for TPU Quant. """
@@ -393,8 +393,7 @@ class BartEncoderLayer(nn.Module):
             cache_config=cache_config,
             quant_config=quant_config)
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config)
+        self.activation_fn = get_act_fn(config.activation_function)

         ffn_hidden_size = self.embed_dim
         ffn_intermediate_size = config.encoder_ffn_dim
@@ -405,7 +404,7 @@ class BartEncoderLayer(nn.Module):
             bias=ffn_has_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, ffn_intermediate_size)
+        self.act = get_act_fn("gelu")
         self.fc2 = RowParallelLinear(
             ffn_intermediate_size,
             ffn_hidden_size,
@@ -473,8 +472,7 @@ class BartDecoderLayer(nn.Module):
             config=config,
             cache_config=cache_config,
             quant_config=quant_config)
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config)
+        self.activation_fn = get_act_fn(config.activation_function)

         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         '''
@@ -146,7 +146,7 @@ class BloomMLP(nn.Module):
             4 * hidden_size,
             quant_config=quant_config,
         )
-        self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.gelu_impl = get_act_fn("gelu")
         self.dense_4h_to_h = RowParallelLinear(
             4 * hidden_size,
             hidden_size,
@@ -212,7 +212,7 @@ class FalconMLP(nn.Module):
                                                   bias=config.bias,
                                                   skip_bias_add=True,
                                                   quant_config=quant_config)
-        self.act = get_act_fn("gelu", quant_config, 4 * hidden_size)
+        self.act = get_act_fn("gelu")
         self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                 or config.parallel_attn)
         self.dense_4h_to_h = RowParallelLinear(
@@ -123,8 +123,7 @@ class GPT2MLP(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.c_proj",
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
@@ -135,8 +135,7 @@ class GPTBigMLP(nn.Module):
             bias=True,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)
@@ -130,8 +130,7 @@ class GPTJMLP(nn.Module):
             hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.activation_function, quant_config,
-                              intermediate_size)
+        self.act = get_act_fn(config.activation_function)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.fc_in(hidden_states)
@@ -128,8 +128,7 @@ class GPTNeoXMLP(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states):
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
@@ -153,7 +153,7 @@ class MPTMLP(nn.Module):
             bias=not config.no_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act = get_act_fn("gelu")
         self.down_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
@@ -147,8 +147,7 @@ class OPTDecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.fc1",
         )
-        self.activation_fn = get_act_fn(config.activation_function,
-                                        quant_config, config.ffn_dim)
+        self.activation_fn = get_act_fn(config.activation_function)
         self.fc2 = RowParallelLinear(
             config.ffn_dim,
             self.embed_dim,
@@ -60,7 +60,7 @@ class PersimmonMLP(nn.Module):
         self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
                                                config.hidden_size,
                                                quant_config=quant_config)
-        self.act = get_act_fn(config.hidden_act, quant_config)
+        self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states) -> torch.Tensor:
         hidden_states, _ = self.dense_h_to_4h(hidden_states)
@@ -152,7 +152,7 @@ class PhiMLP(nn.Module):
             config.hidden_size,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
+        self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states):
         hidden_states, _ = self.fc1(hidden_states)
@@ -203,7 +203,7 @@ class QwenVMLP(nn.Module):
             intermediate_size,
             bias=True,
             quant_config=quant_config)
-        self.act_fn = get_act_fn("gelu", quant_config, intermediate_size)
+        self.act_fn = get_act_fn("gelu")
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
@@ -139,8 +139,7 @@ class Starcoder2MLP(nn.Module):
             bias=config.use_bias,
             quant_config=quant_config,
         )
-        self.act = get_act_fn(config.hidden_act, quant_config,
-                              config.intermediate_size)
+        self.act = get_act_fn(config.hidden_act)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states, _ = self.c_fc(hidden_states)