[Bugfix] MLPSpeculator: Use ParallelLMHead in tie_weights=False case. (#6303)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2024-07-10 15:04:07 +02:00 committed by GitHub
parent e72ae80b06
commit c38eba3046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -110,7 +110,7 @@ class MLPSpeculator(nn.Module):
])
self.head = nn.ModuleList([
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
for _ in range(self.max_speculative_tokens)
])
self.ln = nn.ModuleList([