[Bugfix][Model] Mixtral: use unused head_dim config argument (#14961)

Signed-off-by: Quentin Torroba <quentin.torroba@mistral.ai>
Author:  Quentin, 2025-03-17 15:44:18 +01:00 (committed by GitHub)
Parent:  e1eb45d397
Commit:  aaaec52ad9
2 changed files with 10 additions and 2 deletions

File 1 of 2

@@ -111,6 +111,7 @@ class MixtralAttention(nn.Module):
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -136,7 +137,9 @@ class MixtralAttention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -200,6 +203,7 @@ class MixtralDecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,

File 2 of 2

@@ -165,6 +165,7 @@ class MixtralAttention(nn.Module):
     def __init__(
         self,
+        config: MixtralConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -190,7 +191,9 @@ class MixtralAttention(nn.Module):
         # the KV heads across multiple tensor parallel GPUs.
         assert tp_size % self.total_num_kv_heads == 0
         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
-        self.head_dim = hidden_size // self.total_num_heads
+        # MixtralConfig has an optional head_dim argument
+        self.head_dim = getattr(config, "head_dim",
+                                self.hidden_size // self.total_num_heads)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -252,6 +255,7 @@ class MixtralDecoderLayer(nn.Module):
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = MixtralAttention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             max_position=config.max_position_embeddings,
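
The change is the same in both files: MixtralAttention now receives the model config and prefers an explicit config.head_dim, falling back to the previously hard-coded hidden_size // total_num_heads only when the attribute is absent. A minimal sketch of that fallback behavior, using a hypothetical resolve_head_dim helper and made-up dimensions rather than vLLM code:

from types import SimpleNamespace

def resolve_head_dim(config, hidden_size: int, total_num_heads: int) -> int:
    # Prefer an explicit head_dim on the config; otherwise fall back to the
    # value the old code always used: hidden_size // total_num_heads.
    return getattr(config, "head_dim", hidden_size // total_num_heads)

# Config without head_dim: the fallback applies (4096 // 32 == 128).
cfg_default = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert resolve_head_dim(cfg_default, 4096, 32) == 128

# Config with an explicit head_dim: the configured value wins, even when it
# differs from hidden_size // num_attention_heads.
cfg_explicit = SimpleNamespace(hidden_size=4096, num_attention_heads=32, head_dim=160)
assert resolve_head_dim(cfg_explicit, 4096, 32) == 160

Before this change, checkpoints whose configured head_dim differs from hidden_size // num_attention_heads would have had their per-head dimension derived from hidden_size instead of the configured value.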