[Model][Bugfix] Add FATReLU activation and support for openbmb/MiniCPM-S-1B-sft (#9396)
commit 5b8a1fde84
parent fb60ae9b91
@@ -159,7 +159,7 @@ Text Generation
     -
   * - :code:`MiniCPMForCausalLM`
     - MiniCPM
-    - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc.
+    - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, :code:`openbmb/MiniCPM-S-1B-sft`, etc.
     - ✅︎
     - ✅︎
   * - :code:`MiniCPM3ForCausalLM`
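With this addition, the sparse MiniCPM checkpoint can be served like the other MiniCPM variants. A minimal usage sketch (not part of this commit; it assumes the standard vllm.LLM offline-inference API, and trust_remote_code=True is typically needed for MiniCPM checkpoints that ship custom modeling code):

from vllm import LLM, SamplingParams

# Load the newly supported sparse checkpoint; MiniCPM needs trust_remote_code.
llm = LLM(model="openbmb/MiniCPM-S-1B-sft", trust_remote_code=True)
params = SamplingParams(temperature=0.7, max_tokens=64)

outputs = llm.generate(["Tell me about sparse activation functions."], params)
print(outputs[0].outputs[0].text)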
@@ -13,6 +13,33 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs
 
 
+class FatreluAndMul(CustomOp):
+    """An activation function for FATReLU.
+
+    The function computes x -> FATReLU(x[:d]) * x[d:] where
+    d = x.shape[-1] // 2.
+    This is used in openbmb/MiniCPM-S-1B-sft.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self, threshold: float = 0.):
+        super().__init__()
+        self.threshold = threshold
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        x1 = x[..., :d]
+        x2 = x[..., d:]
+        x1 = F.threshold(x1, self.threshold, 0.0)
+        return x1 * x2
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self.forward_native(x)
+
+
 class SiluAndMul(CustomOp):
     """An activation function for SwiGLU.
 
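The new op's arithmetic can be sanity-checked outside vLLM with plain PyTorch. The sketch below is a hypothetical standalone function that mirrors forward_native, not part of the diff: FATReLU zeroes every element of the first half that is at or below the threshold, then uses the result to gate the second half elementwise:

import torch
import torch.nn.functional as F

def fatrelu_and_mul(x: torch.Tensor, threshold: float = 0.0) -> torch.Tensor:
    # Split the last dimension in half: x1 is thresholded, x2 is the gate input.
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # F.threshold replaces elements <= threshold with 0.0 and keeps the rest.
    return F.threshold(x1, threshold, 0.0) * x2

x = torch.randn(4, 8)                    # (num_tokens, 2 * d)
out = fatrelu_and_mul(x, threshold=0.05)
print(out.shape)                         # torch.Size([4, 4]) -> (num_tokens, d)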
@@ -33,7 +33,7 @@ from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_reduce)
-from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.activation import FatreluAndMul, SiluAndMul
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -152,6 +152,7 @@ class MiniCPMMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
+        hidden_act_param: float,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -163,10 +164,13 @@ class MiniCPMMLP(nn.Module):
                                            hidden_size,
                                            bias=False,
                                            quant_config=quant_config)
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
+        if hidden_act == "silu":
+            self.act_fn = SiluAndMul()
+        elif hidden_act == "fatrelu":
+            self.act_fn = FatreluAndMul(threshold=hidden_act_param)
+        else:
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu and fatrelu are supported for now.")
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
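With the default threshold of 0.0, FatreluAndMul reduces to plain ReLU gating, so existing models are unaffected unless a checkpoint asks for a positive threshold, which is what the sparse MiniCPM-S variant relies on. A quick standalone check of that equivalence (not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(1024)
# At threshold 0.0, FATReLU and ReLU agree on every element:
# both map x <= 0 to 0 and pass x > 0 through unchanged.
assert torch.equal(F.threshold(x, 0.0, 0.0), F.relu(x))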
@@ -304,6 +308,7 @@ class MiniCPMDecoderLayer(nn.Module):
                 hidden_size=self.hidden_size,
                 intermediate_size=self.config.intermediate_size,
                 hidden_act=self.config.hidden_act,
+                hidden_act_param=getattr(self.config, "hidden_act_param", 0.),
                 quant_config=self.quant_config,
             )
         else:
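The threshold is read from the HuggingFace config via getattr, so configs without a hidden_act_param field fall back to 0.0 and keep their previous behavior. A sketch of that lookup with stand-in config objects (the SimpleNamespace configs and the 0.03 value are illustrative, not taken from the real checkpoints):

from types import SimpleNamespace

dense_cfg = SimpleNamespace(hidden_act="silu")                              # no hidden_act_param
sparse_cfg = SimpleNamespace(hidden_act="fatrelu", hidden_act_param=0.03)   # illustrative value

for cfg in (dense_cfg, sparse_cfg):
    threshold = getattr(cfg, "hidden_act_param", 0.)  # same default as the diff
    print(cfg.hidden_act, threshold)                  # silu 0.0 / fatrelu 0.03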