[Bugfix] Adjust mllama to regional compilation (#15112)
Signed-off-by: Jan Kaniecki <jkaniecki@habana.ai>
This commit is contained in:
parent
6c5a3195db
commit
8363cd093d
@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
|
|||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
for decoder_layer in self.layers:
|
for idx, decoder_layer in enumerate(self.layers):
|
||||||
if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
|
if idx in self.cross_attention_layers:
|
||||||
if not skip_cross_attention:
|
if not skip_cross_attention:
|
||||||
hidden_states = decoder_layer(
|
hidden_states = decoder_layer(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
|
|||||||
full_text_row_masked_out_mask=
|
full_text_row_masked_out_mask=
|
||||||
full_text_row_masked_out_mask,
|
full_text_row_masked_out_mask,
|
||||||
)
|
)
|
||||||
elif isinstance(decoder_layer, LlamaDecoderLayer):
|
else:
|
||||||
hidden_states, residual = decoder_layer(
|
hidden_states, residual = decoder_layer(
|
||||||
positions=positions,
|
positions=positions,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=None,
|
residual=None,
|
||||||
)
|
)
|
||||||
hidden_states = hidden_states + residual
|
hidden_states = hidden_states + residual
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown decoder layer type {type(decoder_layer)}")
|
|
||||||
hidden_states = self.norm(hidden_states)
|
hidden_states = self.norm(hidden_states)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user