[Bugfix] Adjust mllama to regional compilation (#15112)

Signed-off-by: Jan Kaniecki <jkaniecki@habana.ai>
This commit is contained in:
Jan Kaniecki 2025-03-19 15:57:25 +01:00 committed by GitHub
parent 6c5a3195db
commit 8363cd093d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds hidden_states = inputs_embeds
for decoder_layer in self.layers: for idx, decoder_layer in enumerate(self.layers):
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer): if idx in self.cross_attention_layers:
if not skip_cross_attention: if not skip_cross_attention:
hidden_states = decoder_layer( hidden_states = decoder_layer(
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
full_text_row_masked_out_mask= full_text_row_masked_out_mask=
full_text_row_masked_out_mask, full_text_row_masked_out_mask,
) )
elif isinstance(decoder_layer, LlamaDecoderLayer): else:
hidden_states, residual = decoder_layer( hidden_states, residual = decoder_layer(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
residual=None, residual=None,
) )
hidden_states = hidden_states + residual hidden_states = hidden_states + residual
else:
raise ValueError(
f"Unknown decoder layer type {type(decoder_layer)}")
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states