[Model] Changes to MLPSpeculator to support tie_weights and input_scale (#5965)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>

commit 54600709b6 (parent e373853e12)
vllm/model_executor/models/mlp_speculator.py
(whitespace-only re-indentation of existing lines under the new if/else blocks is shown as unchanged context)

@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import MLPSpeculatorConfig
 
+SQRT2 = 2**0.5
+
 
 class MLPSpeculatorLayerNorm(nn.Module):
     """
@@ -26,14 +28,19 @@ class MLPSpeculatorLayerNorm(nn.Module):
         Safety term to prevent division by zero. Make sure the chosen value
         fits in the range of your encoding scheme
         (i.e. fp16 requires eps >= 6e-8).
+    elementwise_scale_and_shift : bool
+        Include a learned scaling and shift term after normalization.
     """
 
     def __init__(
         self,
         normalized_shape,
         eps=1e-06,
+        elementwise_scale_and_shift=True,
     ):
         super(MLPSpeculatorLayerNorm, self).__init__()
+        self.elementwise_scale_and_shift = elementwise_scale_and_shift
+        if self.elementwise_scale_and_shift:
             self.weight = nn.Parameter(torch.empty(normalized_shape))
             self.bias = nn.Parameter(torch.empty(normalized_shape))
         self.eps = eps
@@ -42,6 +49,7 @@ class MLPSpeculatorLayerNorm(nn.Module):
         xf = x
         xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
         x = xf.type_as(x)
+        if self.elementwise_scale_and_shift:
             x = self.weight * x
             x = x + self.bias
         return x
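For context on the layer-norm change above: the module is an RMS-style norm, so it rescales the last dimension to unit mean square and only applies the learned weight and bias when elementwise_scale_and_shift is set; with the flag off it registers no parameters at all, which is what the new parameter-free ln0 path below relies on. A small standalone numeric check of the normalization (plain PyTorch, not an import of vLLM's class):

import torch

eps = 1e-06
x = torch.randn(2, 5, 64) * 10.0  # arbitrary input scale
# same operation as MLPSpeculatorLayerNorm.forward before the optional affine step
xf = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

print(x.pow(2).mean(-1)[0, 0])   # large, depends on the input scale (~100 here)
print(xf.pow(2).mean(-1)[0, 0])  # ~1.0 after normalization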
@@ -59,6 +67,34 @@ class MLPSpeculator(nn.Module):
 
         self.max_speculative_tokens = config.num_lookahead_tokens
 
+        self.tie_weights = config.tie_weights
+        self.scale_input = config.scale_input
+
+        if self.tie_weights:
+            assert (
+                self.n_predict >
+                1), "You cannot tie weights between stages when only 1 exists"
+            embedding = VocabParallelEmbedding(
+                config.vocab_size,
+                self.inner_dim,
+                org_num_embeddings=config.vocab_size)
+            self.emb = nn.ModuleList([embedding] * self.max_speculative_tokens)
+
+            # the initial projection from the base model may
+            # have a different size, so that stays separate.
+            proj_first = nn.Linear(self.emb_dim, self.inner_dim, bias=False)
+            proj_tied = nn.Linear(self.inner_dim, self.inner_dim, bias=False)
+            self.proj = nn.ModuleList([proj_first] + [proj_tied] *
+                                      (self.max_speculative_tokens - 1))
+
+            head = nn.Linear(self.inner_dim, self.vocab_size, bias=False)
+            self.head = nn.ModuleList([head] * self.max_speculative_tokens)
+
+            ln = MLPSpeculatorLayerNorm(self.inner_dim,
+                                        elementwise_scale_and_shift=True)
+            self.ln = nn.ModuleList([ln] * self.max_speculative_tokens)
+
+        else:
             self.emb = nn.ModuleList([
                 VocabParallelEmbedding(config.vocab_size,
                                        self.inner_dim,
@@ -69,7 +105,8 @@ class MLPSpeculator(nn.Module):
             self.proj = nn.ModuleList([
                 nn.Linear((self.emb_dim if i == 0 else self.inner_dim),
                           self.inner_dim,
-                          bias=False) for i in range(self.max_speculative_tokens)
+                          bias=False)
+                for i in range(self.max_speculative_tokens)
             ])
 
             self.head = nn.ModuleList([
@@ -77,9 +114,13 @@ class MLPSpeculator(nn.Module):
                 for _ in range(self.max_speculative_tokens)
             ])
             self.ln = nn.ModuleList([
-                MLPSpeculatorLayerNorm(self.inner_dim)
+                MLPSpeculatorLayerNorm(self.inner_dim,
+                                       elementwise_scale_and_shift=True)
                 for _ in range(self.max_speculative_tokens)
            ])
+        if self.scale_input:
+            self.ln0 = MLPSpeculatorLayerNorm(
+                self.emb_dim, elementwise_scale_and_shift=False)
 
         self.state_weight = 0.5**(0.5 / config.n_predict)
         self.emb_weight = math.sqrt(
@@ -105,6 +146,9 @@ class MLPSpeculator(nn.Module):
         # b x 1 x d
         previous_hidden_states = previous_hidden_states.unsqueeze(1)
 
+        if self.scale_input:
+            previous_hidden_states = self.ln0(previous_hidden_states) / SQRT2
+
         # b x 1
         last_tokens = input_ids.unsqueeze(1)
 
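Two mechanics in the constructor and forward changes above are easy to miss. First, the tie_weights branch builds each nn.ModuleList from one module repeated by list multiplication, so every stage after the first refers to the same underlying object and therefore shares a single set of weights. A minimal sketch of that effect in plain PyTorch (illustrative sizes, not vLLM code):

import torch.nn as nn

head = nn.Linear(16, 32, bias=False)
tied = nn.ModuleList([head] * 4)  # the same Linear registered four times
untied = nn.ModuleList([nn.Linear(16, 32, bias=False) for _ in range(4)])

print(tied[0].weight is tied[3].weight)              # True  -> one shared weight tensor
print(untied[0].weight is untied[3].weight)          # False -> independent tensors
print(sum(p.numel() for p in tied.parameters()))     # 512 (duplicate registrations are deduplicated)
print(sum(p.numel() for p in untied.parameters()))   # 2048

Second, when scale_input is enabled, forward() passes the incoming hidden states through the parameter-free ln0 and divides by SQRT2, so the speculator sees inputs with a mean square of roughly 0.5 regardless of the base model's hidden-state scale.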
vllm/transformers_utils/configs/mlp_speculator.py

@@ -17,6 +17,8 @@ class MLPSpeculatorConfig(PretrainedConfig):
                  n_predict: int = 3,
                  top_k_tokens_per_head: Optional[List[int]] = None,
                  n_candidates: int = 5,
+                 tie_weights: bool = False,
+                 scale_input: bool = False,
                  **kwargs):
         """
         Initialize an MLPSpeculatorConfig
@@ -38,6 +40,14 @@ class MLPSpeculatorConfig(PretrainedConfig):
             NOTE: This parameter is currently unused.
         n_candidates: int
             number of child candidates to create per sequence
+        tie_weights: bool
+            If true, use a single set of weights for every model
+            head/stage after the first. The initial projection
+            from the base model may have a different size, so that
+            stays separate.
+        scale_input: bool
+            if True, will scale the initial hidden states from
+            the base model.
         """
         if top_k_tokens_per_head is None:
             top_k_tokens_per_head = [5, 4, 3]
@@ -49,5 +59,7 @@ class MLPSpeculatorConfig(PretrainedConfig):
         self.top_k_tokens_per_head = top_k_tokens_per_head
         self.n_candidates = n_candidates
         self.num_lookahead_tokens = n_predict
+        self.tie_weights = tie_weights
+        self.scale_input = scale_input
 
         super().__init__(**kwargs)
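A hedged usage sketch of the two new config flags. The other constructor fields shown here (vocab_size, emb_dim, inner_dim) are recalled from the surrounding MLPSpeculatorConfig definition rather than quoted from this diff, so treat them as illustrative:

from vllm.transformers_utils.configs import MLPSpeculatorConfig

config = MLPSpeculatorConfig(
    vocab_size=32000,     # illustrative values, assumed field names
    emb_dim=4096,
    inner_dim=3072,
    n_predict=4,
    tie_weights=True,     # reuse one emb/proj/head/ln for every stage after the first
    scale_input=True,     # RMS-normalize and scale the base model's hidden states
)
print(config.tie_weights, config.scale_input, config.num_lookahead_tokens)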