[Bugfix] Enable Proper attention_bias Usage in Llama Model Configuration (#3767)
Co-authored-by: roy <jasonailu87@gmail.com>
parent f46864d68d
commit bc0c0192d1
@@ -184,6 +184,10 @@ class LlamaDecoderLayer(nn.Module):
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         sliding_window = getattr(config, "sliding_window", None)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -193,7 +197,7 @@ class LlamaDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
-            bias=getattr(config, "bias", False),
+            bias=attention_bias,
             sliding_window=sliding_window,
         )
         self.mlp = LlamaMLP(
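The commit replaces the hard-coded bias=getattr(config, "bias", False) argument with a flag resolved from either attention_bias (used by abacusai/Smaug-72B-v0.1) or the legacy bias attribute (used by internlm/internlm-7b). Below is a minimal sketch of that resolution logic; the SimpleNamespace configs and the resolve_attention_bias helper are illustrative stand-ins for the real HuggingFace config objects, not part of the commit.

from types import SimpleNamespace

# Hypothetical stand-ins for HuggingFace config objects; only the
# attributes relevant to this change are modeled.
smaug_cfg = SimpleNamespace(attention_bias=True)   # abacusai/Smaug-72B-v0.1 style
internlm_cfg = SimpleNamespace(bias=True)          # internlm/internlm-7b style
plain_cfg = SimpleNamespace()                      # vanilla Llama: neither attribute set

def resolve_attention_bias(config) -> bool:
    # Same resolution as the lines added in LlamaDecoderLayer.__init__:
    # prefer `attention_bias`, fall back to `bias`, default to False.
    return getattr(config, "attention_bias", False) or getattr(config, "bias", False)

assert resolve_attention_bias(smaug_cfg) is True
assert resolve_attention_bias(internlm_cfg) is True
assert resolve_attention_bias(plain_cfg) is False

With this change, the resolved flag is what gets passed as bias= to LlamaAttention, so models that express the attention bias through either config attribute are handled by the same code path.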