[Model] Changes to MLPSpeculator to support tie_weights and scale_input (#5965)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
This commit is contained in:
parent
e373853e12
commit
54600709b6
@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import MLPSpeculatorConfig
|
||||
|
||||
SQRT2 = 2**0.5
|
||||
|
||||
|
||||
class MLPSpeculatorLayerNorm(nn.Module):
|
||||
"""
|
||||
@ -26,24 +28,30 @@ class MLPSpeculatorLayerNorm(nn.Module):
|
||||
Safety term to prevent division by zero. Make sure the chosen value
|
||||
fits in the range of your encoding scheme
|
||||
(i.e. fp16 requires eps >= 6e-8).
|
||||
elementwise_scale_and_shift : bool
|
||||
Include a learned scaling and shift term after normalization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape,
|
||||
eps=1e-06,
|
||||
elementwise_scale_and_shift=True,
|
||||
):
|
||||
super(MLPSpeculatorLayerNorm, self).__init__()
|
||||
self.weight = nn.Parameter(torch.empty(normalized_shape))
|
||||
self.bias = nn.Parameter(torch.empty(normalized_shape))
|
||||
self.elementwise_scale_and_shift = elementwise_scale_and_shift
|
||||
if self.elementwise_scale_and_shift:
|
||||
self.weight = nn.Parameter(torch.empty(normalized_shape))
|
||||
self.bias = nn.Parameter(torch.empty(normalized_shape))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
xf = x
|
||||
xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
x = xf.type_as(x)
|
||||
x = self.weight * x
|
||||
x = x + self.bias
|
||||
if self.elementwise_scale_and_shift:
|
||||
x = self.weight * x
|
||||
x = x + self.bias
|
||||
return x
|
||||
|
||||
|
||||
@ -59,27 +67,60 @@ class MLPSpeculator(nn.Module):
|
||||
|
||||
self.max_speculative_tokens = config.num_lookahead_tokens
|
||||
|
||||
self.emb = nn.ModuleList([
|
||||
VocabParallelEmbedding(config.vocab_size,
|
||||
self.inner_dim,
|
||||
org_num_embeddings=config.vocab_size)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
])
|
||||
self.tie_weights = config.tie_weights
|
||||
self.scale_input = config.scale_input
|
||||
|
||||
self.proj = nn.ModuleList([
|
||||
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
|
||||
self.inner_dim,
|
||||
bias=False) for i in range(self.max_speculative_tokens)
|
||||
])
|
||||
if self.tie_weights:
|
||||
assert (
|
||||
self.n_predict >
|
||||
1), "You cannot tie weights between stages when only 1 exists"
|
||||
embedding = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
self.inner_dim,
|
||||
org_num_embeddings=config.vocab_size)
|
||||
self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
|
||||
|
||||
self.head = nn.ModuleList([
|
||||
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
])
|
||||
self.ln = nn.ModuleList([
|
||||
MLPSpeculatorLayerNorm(self.inner_dim)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
])
|
||||
# the initial projection from the base model may
|
||||
# have a different size, so that stays separate.
|
||||
proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
|
||||
proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
|
||||
self.proj = nn.ModuleList([proj_first] + [proj_tied] *
|
||||
(self.max_speculative_tokens - 1))
|
||||
|
||||
head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
|
||||
self.head = nn.ModuleList([head] * self.max_speculative_tokens)
|
||||
|
||||
ln = MLPSpeculatorLayerNorm(self.inner_dim,
|
||||
elementwise_scale_and_shift=True)
|
||||
self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
|
||||
|
||||
else:
|
||||
self.emb = nn.ModuleList([
|
||||
VocabParallelEmbedding(config.vocab_size,
|
||||
self.inner_dim,
|
||||
org_num_embeddings=config.vocab_size)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
])
|
||||
|
||||
self.proj = nn.ModuleList([
|
||||
nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
|
||||
self.inner_dim,
|
||||
bias=False)
|
||||
for i in range(self.max_speculative_tokens)
|
||||
])
|
||||
|
||||
self.head = nn.ModuleList([
|
||||
nn.Linear(self.inner_dim, self.vocab_size, bias=False)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
])
|
||||
self.ln = nn.ModuleList([
|
||||
MLPSpeculatorLayerNorm(self.inner_dim,
|
||||
elementwise_scale_and_shift=True)
|
||||
for _ in range(self.max_speculative_tokens)
|
||||
])
|
||||
if self.scale_input:
|
||||
self.ln0 = MLPSpeculatorLayerNorm(
|
||||
self.emb_dim, elementwise_scale_and_shift=False)
|
||||
|
||||
self.state_weight = 0.5**(0.5 / config.n_predict)
|
||||
self.emb_weight = math.sqrt(
|
||||
@ -105,6 +146,9 @@ class MLPSpeculator(nn.Module):
|
||||
# b x 1 x d
|
||||
previous_hidden_states = previous_hidden_states.unsqueeze(1)
|
||||
|
||||
if self.scale_input:
|
||||
previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
|
||||
|
||||
# b x 1
|
||||
last_tokens = input_ids.unsqueeze(1)
|
||||
|
||||
|
@ -17,6 +17,8 @@ class MLPSpeculatorConfig(PretrainedConfig):
|
||||
n_predict: int = 3,
|
||||
top_k_tokens_per_head: Optional[List[int]] = None,
|
||||
n_candidates: int = 5,
|
||||
tie_weights: bool = False,
|
||||
scale_input: bool = False,
|
||||
**kwargs):
|
||||
"""
|
||||
Initialize an MLPSpeculatorConfig
|
||||
@ -38,6 +40,14 @@ class MLPSpeculatorConfig(PretrainedConfig):
|
||||
NOTE: This parameter is currently unused.
|
||||
n_candidates: int
|
||||
number of child candidates to create per sequence
|
||||
tie_weights: bool
|
||||
If true, use a single set of weights for every model
|
||||
head/stage after the first. The initial projection
|
||||
from the base model may have a different size, so that
|
||||
stays separate.
|
||||
scale_input: bool
|
||||
If true, scale the initial hidden states from
|
||||
the base model.
|
||||
"""
|
||||
if top_k_tokens_per_head is None:
|
||||
top_k_tokens_per_head = [5, 4, 3]
|
||||
@ -49,5 +59,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
|
||||
self.top_k_tokens_per_head = top_k_tokens_per_head
|
||||
self.n_candidates = n_candidates
|
||||
self.num_lookahead_tokens = n_predict
|
||||
self.tie_weights = tie_weights
|
||||
self.scale_input = scale_input
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
Loading…
x
Reference in New Issue
Block a user