diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index 0ae76887..6bdb6235 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -111,6 +111,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -136,7 +137,9 @@ class MixtralAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -200,6 +203,7 @@ class MixtralDecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index 8a893b6d..5be91f40 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -165,6 +165,7 @@ class MixtralAttention(nn.Module):
 
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -190,7 +191,9 @@ class MixtralAttention(nn.Module):
             # the KV heads across multiple tensor parallel GPUs.
             assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -252,6 +255,7 @@ class MixtralDecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
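
For context, a minimal, self-contained sketch of the fallback behaviour the patch introduces in both files: an explicit head_dim attribute on the config takes precedence, while configs without the attribute keep the old hidden_size // total_num_heads derivation. The helper name resolve_head_dim and the SimpleNamespace stand-ins below are illustrative only, not part of vLLM or the patch.

```python
from types import SimpleNamespace

# Sketch only (not vLLM code): mirrors the getattr fallback added in the diff.
# If the config has a head_dim attribute, it is used as-is; if the attribute
# is absent, head_dim is derived from hidden_size and the total head count.
def resolve_head_dim(config, hidden_size: int, total_num_heads: int) -> int:
    return getattr(config, "head_dim", hidden_size // total_num_heads)

# Config without head_dim: falls back to hidden_size // total_num_heads.
plain_cfg = SimpleNamespace()
assert resolve_head_dim(plain_cfg, hidden_size=4096, total_num_heads=32) == 128

# Config with an explicit head_dim: that value is used directly.
custom_cfg = SimpleNamespace(head_dim=64)
assert resolve_head_dim(custom_cfg, hidden_size=4096, total_num_heads=32) == 64
```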