[Bugfix] Adjust mllama to regional compilation (#15112)

Signed-off-by: Jan Kaniecki <jkaniecki@habana.ai>
Jan Kaniecki 2025-03-19 15:57:25 +01:00 committed by GitHub
parent 6c5a3195db
commit 8363cd093d


@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
         inputs_embeds = self.embed_tokens(input_ids)
         hidden_states = inputs_embeds
 
-        for decoder_layer in self.layers:
-            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
+        for idx, decoder_layer in enumerate(self.layers):
+            if idx in self.cross_attention_layers:
                 if not skip_cross_attention:
                     hidden_states = decoder_layer(
                         hidden_states=hidden_states,
@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
                         full_text_row_masked_out_mask=
                         full_text_row_masked_out_mask,
                     )
-            elif isinstance(decoder_layer, LlamaDecoderLayer):
+            else:
                 hidden_states, residual = decoder_layer(
                     positions=positions,
                     hidden_states=hidden_states,
                     residual=None,
                 )
                 hidden_states = hidden_states + residual
-            else:
-                raise ValueError(
-                    f"Unknown decoder layer type {type(decoder_layer)}")
 
         hidden_states = self.norm(hidden_states)
         return hidden_states
@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
     full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
     mask *= full_text_mask
     # (num_prompt_tokens, num_encoder_tokens)
-    return mask
+    return mask
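
For context, the change swaps type-based dispatch (isinstance checks against MllamaCrossAttentionDecoderLayer / LlamaDecoderLayer) for index-based dispatch via self.cross_attention_layers, the list of cross-attention layer indices from the Mllama text config. The likely motivation is that regional compilation wraps each decoder layer individually (e.g. with torch.compile), so the wrapped module's runtime type is no longer the original layer class and the isinstance branches stop matching. The following is a minimal sketch of that failure mode, not the vLLM implementation; SelfAttnLayer, CrossAttnLayer and ToyTextModel are illustrative names, and the "eager" backend is chosen only to avoid codegen when running the demo.

# Minimal sketch (not the vLLM code): index-based dispatch survives
# per-layer ("regional") compilation, while isinstance checks do not.
# Assumes torch >= 2.0.
import torch
import torch.nn as nn


class SelfAttnLayer(nn.Module):
    def forward(self, x):
        return x + 1


class CrossAttnLayer(nn.Module):
    def forward(self, x, cross_states):
        return x + cross_states


class ToyTextModel(nn.Module):
    def __init__(self, num_layers=4, cross_attention_layers=(1, 3)):
        super().__init__()
        # Indices of cross-attention layers, analogous to Mllama's
        # config.cross_attention_layers.
        self.cross_attention_layers = list(cross_attention_layers)
        self.layers = nn.ModuleList(
            CrossAttnLayer() if i in self.cross_attention_layers
            else SelfAttnLayer() for i in range(num_layers))

    def compile_layers(self):
        # Regional compilation: wrap each decoder layer individually.
        # torch.compile returns an OptimizedModule wrapper, so the layers'
        # runtime type is no longer SelfAttnLayer / CrossAttnLayer.
        self.layers = nn.ModuleList(
            torch.compile(layer, backend="eager") for layer in self.layers)

    def forward(self, x, cross_states):
        for idx, layer in enumerate(self.layers):
            # Dispatch on the index recorded at construction time instead
            # of the (possibly wrapped) runtime type of the layer.
            if idx in self.cross_attention_layers:
                x = layer(x, cross_states)
            else:
                x = layer(x)
        return x


model = ToyTextModel()
model.compile_layers()
print(isinstance(model.layers[1], CrossAttnLayer))  # False once wrapped
print(1 in model.cross_attention_layers)            # True: index check still holds
print(model(torch.zeros(2), torch.ones(2)))         # forward still dispatches correctly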