[Bugfix] MLPSpeculator: Use ParallelLMHead in tie_weights=False case. (#6303)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
e72ae80b06
commit
c38eba3046
@ -110,7 +110,7 @@ class MLPSpeculator(nn.Module):
|
||||
])
|
||||
|
||||
self.head = nn.ModuleList([
|
||||
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
|
||||
ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
])
|
||||
self.ln = nn.ModuleList([
|
||||
|
Loading…
x
Reference in New Issue
Block a user