[Model] Changes to MLPSpeculator to support tie_weights and input_scale (#5965)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
Authored by Thomas Parnell on 2024-07-02 01:40:02 +02:00; committed by GitHub.
parent e373853e12
commit 54600709b6
2 changed files with 79 additions and 23 deletions

File: vllm/model_executor/models/mlp_speculator.py

@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
+SQRT2 = 2**0.5
+
 
 class MLPSpeculatorLayerNorm(nn.Module):
     """
@@ -26,14 +28,19 @@ class MLPSpeculatorLayerNorm(nn.Module):
         Safety term to prevent division by zero. Make sure the chosen value
         fits in the range of your encoding scheme
         (i.e. fp16 requires eps >= 6e-8).
+    elementwise_scale_and_shift : bool
+        Include a learned scaling and shift term after normalization.
     """
 
     def __init__(
         self,
         normalized_shape,
         eps=1e-06,
+        elementwise_scale_and_shift=True,
     ):
         super(MLPSpeculatorLayerNorm, self).__init__()
-        self.weight = nn.Parameter(torch.empty(normalized_shape))
-        self.bias = nn.Parameter(torch.empty(normalized_shape))
+        self.elementwise_scale_and_shift = elementwise_scale_and_shift
+        if self.elementwise_scale_and_shift:
+            self.weight = nn.Parameter(torch.empty(normalized_shape))
+            self.bias = nn.Parameter(torch.empty(normalized_shape))
         self.eps = eps
@@ -42,6 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
         xf = x
         xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
         x = xf.type_as(x)
-        x = self.weight * x
-        x = x + self.bias
+        if self.elementwise_scale_and_shift:
+            x = self.weight * x
+            x = x + self.bias
         return x
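Note: with elementwise_scale_and_shift=False the layer reduces to a parameter-free RMS-style normalization, which is what the new ln0 input-scaling path below relies on. A minimal standalone sketch of the two modes (plain PyTorch, independent of the vLLM classes):

    import torch

    def mlp_spec_norm(x, weight=None, bias=None, eps=1e-6):
        # Normalize by the root mean square of the last axis.
        xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        xf = xf.type_as(x)
        if weight is not None:  # elementwise_scale_and_shift=True
            xf = weight * xf + bias
        return xf

    x = torch.randn(2, 4, 8)
    # Without scale/shift, every vector comes out with (near) unit RMS.
    y = mlp_spec_norm(x)
    print(y.pow(2).mean(-1))  # ~1.0 everywhere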
@@ -59,6 +67,34 @@ class MLPSpeculator(nn.Module):
 
         self.max_speculative_tokens = config.num_lookahead_tokens
 
+        self.tie_weights = config.tie_weights
+        self.scale_input = config.scale_input
+
+        if self.tie_weights:
+            assert (
+                self.n_predict >
+                1), "You cannot tie weights between stages when only 1 exists"
+            embedding = VocabParallelEmbedding(
+                config.vocab_size,
+                self.inner_dim,
+                org_num_embeddings=config.vocab_size)
+
+            self.emb = nn.ModuleList([embedding] *
+                                     self.max_speculative_tokens)
+
+            # the initial projection from the base model may
+            # have a different size, so that stays separate.
+            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
+            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
+            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
+                                      (self.max_speculative_tokens - 1))
+
+            head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+            self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+
+            ln = MLPSpeculatorLayerNorm(self.inner_dim,
+                                        elementwise_scale_and_shift=True)
+            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
+        else:
-        self.emb = nn.ModuleList([
-            VocabParallelEmbedding(config.vocab_size,
-                                   self.inner_dim,
+            self.emb = nn.ModuleList([
+                VocabParallelEmbedding(config.vocab_size,
+                                       self.inner_dim,
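The weight tying above relies on Python list semantics: nn.ModuleList([m] * n) registers the same module object n times, so every stage shares one set of parameters, unlike the fresh modules built per stage in the else branch. A small sketch illustrating the difference:

    import torch.nn as nn

    n = 4
    tied = nn.ModuleList([nn.Linear(8, 8, bias=False)] * n)
    untied = nn.ModuleList([nn.Linear(8, 8, bias=False) for _ in range(n)])

    # Tied: n references to one module, one underlying weight tensor.
    assert tied[0] is tied[3]
    assert len(list(tied.parameters())) == 1
    # Untied: n independent weight tensors.
    assert len(list(untied.parameters())) == n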
@@ -69,7 +105,8 @@ class MLPSpeculator(nn.Module):
-        self.proj = nn.ModuleList([
-            nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
-                      self.inner_dim,
-                      bias=False) for i in range(self.max_speculative_tokens)
-        ])
+            self.proj = nn.ModuleList([
+                nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
+                          self.inner_dim,
+                          bias=False)
+                for i in range(self.max_speculative_tokens)
+            ])
 
-        self.head = nn.ModuleList([
+            self.head = nn.ModuleList([
@@ -77,9 +114,13 @@ class MLPSpeculator(nn.Module):
-            for _ in range(self.max_speculative_tokens)
-        ])
-        self.ln = nn.ModuleList([
-            MLPSpeculatorLayerNorm(self.inner_dim)
-            for _ in range(self.max_speculative_tokens)
-        ])
+                for _ in range(self.max_speculative_tokens)
+            ])
+            self.ln = nn.ModuleList([
+                MLPSpeculatorLayerNorm(self.inner_dim,
+                                       elementwise_scale_and_shift=True)
+                for _ in range(self.max_speculative_tokens)
+            ])
+        if self.scale_input:
+            self.ln0 = MLPSpeculatorLayerNorm(
+                self.emb_dim, elementwise_scale_and_shift=False)
 
         self.state_weight = 0.5**(0.5 / config.n_predict)
         self.emb_weight = math.sqrt(
@@ -105,6 +146,9 @@ class MLPSpeculator(nn.Module):
         # b x 1 x d
         previous_hidden_states = previous_hidden_states.unsqueeze(1)
 
+        if self.scale_input:
+            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
+
         # b x 1
         last_tokens = input_ids.unsqueeze(1)
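A quick check of the input-scaling arithmetic: ln0 brings the hidden states to roughly unit mean square, and the division by SQRT2 then halves their power regardless of the base model's output scale. For example:

    import torch

    SQRT2 = 2**0.5
    h = torch.randn(2, 1, 16) * 3.7  # arbitrary scale from the base model
    # The parameter-free core of ln0:
    hn = h * torch.rsqrt(h.pow(2).mean(-1, keepdim=True) + 1e-6)
    print((hn / SQRT2).pow(2).mean(-1))  # ~0.5 for every vector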

File: vllm/transformers_utils/configs/mlp_speculator.py

@@ -17,6 +17,8 @@ class MLPSpeculatorConfig(PretrainedConfig):
                  n_predict: int = 3,
                  top_k_tokens_per_head: Optional[List[int]] = None,
                  n_candidates: int = 5,
+                 tie_weights: bool = False,
+                 scale_input: bool = False,
                  **kwargs):
         """
         Initialize an MLPSpeculatorConfig
@@ -38,6 +40,14 @@ class MLPSpeculatorConfig(PretrainedConfig):
             NOTE: This parameter is currently unused.
         n_candidates: int
             number of child candidates to create per sequence
+        tie_weights: bool
+            If true, use a single set of weights for every model
+            head/stage after the first. The initial projection
+            from the base model may have a different size, so that
+            stays separate.
+        scale_input: bool
+            If true, will scale the initial hidden states from
+            the base model.
         """
         if top_k_tokens_per_head is None:
             top_k_tokens_per_head = [5, 4, 3]
@@ -49,5 +59,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
         self.num_lookahead_tokens = n_predict
+        self.tie_weights = tie_weights
+        self.scale_input = scale_input
 
         super().__init__(**kwargs)
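With the two new flags, a speculator checkpoint trained with shared stage weights and scaled inputs can be described directly in its config. A hypothetical construction (the field values here are illustrative, not taken from the commit):

    from vllm.transformers_utils.configs import MLPSpeculatorConfig

    config = MLPSpeculatorConfig(
        vocab_size=32000,
        emb_dim=4096,
        inner_dim=4096,
        n_predict=4,          # tie_weights requires n_predict > 1
        tie_weights=True,     # share weights across stages after the first
        scale_input=True,     # RMS-normalize base hidden states, then / SQRT2
    )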