[Bugfix] MLPSpeculator: Use ParallelLMHead in tie_weights=False case. (#6303)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell 2024-07-10 15:04:07 +02:00 committed by GitHub
parent e72ae80b06
commit c38eba3046
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -110,7 +110,7 @@ class MLPSpeculator(nn.Module):
 ])
 self.head = nn.ModuleList([
-nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+ParallelLMHead(self.vocab_size, self.inner_dim, bias=False)
 for _ in range(self.max_speculative_tokens)
 ])
 self.ln = nn.ModuleList([