[Bugfix] MLPSpeculator: Use ParallelLMHead in tie_weights=False case. (#6303)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent
e72ae80b06
commit
c38eba3046
@ -110,7 +110,7 @@ class MLPSpeculator(nn.Module):
|
|||||||
])
|
])
|
||||||
|
|
||||||
self.head = nn.ModuleList([
|
self.head = nn.ModuleList([
|
||||||
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
|
ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
|
||||||
for _ in range(self.max_speculative_tokens)
|
for _ in range(self.max_speculative_tokens)
|
||||||
])
|
])
|
||||||
self.ln = nn.ModuleList([
|
self.ln = nn.ModuleList([
|
||||||
|
Loading…
x
Reference in New Issue
Block a user