Correction to TP logic for Mamba Mixer 2 when Num Groups not divisible by TP Size (#13660)
parent da31b5333e
commit fca20841c2
@@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
     if ngroups % tp_size == 0:
         return 0
 
-    return tp_size - ngroups % tp_size
+    # for n_groups == 1, this is exactly tp_size - n_groups
+    return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
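To see what the corrected helper returns, here is a small standalone sketch (not part of the commit); the function body mirrors the diff above, and the example values are illustrative only:

def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    # when the TP size already divides the groups, no replication is needed
    if ngroups % tp_size == 0:
        return 0

    # for n_groups == 1, this is exactly tp_size - n_groups
    return tp_size - ngroups


# 8-way TP with a single group: 7 replicated groups are added so that
# every head shard has a group to go with it
assert extra_groups_for_head_shards(ngroups=1, tp_size=8) == 7
# evenly divisible case: nothing extra is allocated
assert extra_groups_for_head_shards(ngroups=4, tp_size=2) == 0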
@@ -153,7 +154,7 @@ def mamba_v2_sharded_weight_loader(
         boundary, loaded_boundary = 0, 0
 
         # - iterate over the shard specs
-        for full_dim, extra, ratio in shard_spec:
+        for full_dim, extra, duplicate_groups in shard_spec:
             # - full dim is the model dim (before TP).
             # - extra > 0, means there is expected overall increase
             #   of dimensions. This is so because of replication.
@@ -167,7 +168,12 @@ def mamba_v2_sharded_weight_loader(
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            rank = tp_rank // ratio
+            if duplicate_groups:
+                # NOTE: currently we only support duplication
+                # in the case where num_groups == 1
+                rank = 0
+            else:
+                rank = tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
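The two hunks above replace the numeric ratio field of each shard spec with a boolean duplicate_groups flag and use it to decide which slice of the checkpoint tensor a TP rank loads. A simplified standalone sketch of that selection (hypothetical helper name, numpy instead of torch, not the actual vLLM loader):

import numpy as np


def select_loaded_slice(loaded_weight: np.ndarray, tp_rank: int, tp_size: int,
                        duplicate_groups: bool) -> np.ndarray:
    # replication: every TP rank reads the same (rank-0) slice, which is the
    # whole single-group tensor in the only supported case, num_groups == 1
    if duplicate_groups:
        return loaded_weight
    # normal sharding: each rank reads its own contiguous slice
    shard_size = loaded_weight.shape[0] // tp_size
    loaded_skip = tp_rank * shard_size  # leftmost boundary into loaded weight
    return loaded_weight[loaded_skip:loaded_skip + shard_size]


w = np.arange(8.0)
# duplicated groups: ranks 0 and 1 load identical data
assert np.array_equal(select_loaded_slice(w, 0, 2, True),
                      select_loaded_slice(w, 1, 2, True))
# sharded: rank 1 loads the second half
assert np.array_equal(select_loaded_slice(w, 1, 2, False), w[4:])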
@@ -233,12 +239,21 @@ class MambaMixer2(CustomOp):
         # - HOWEVER IF, world_size DOES NOT divide groups, then we need
         #   to allocate extra space in the shard, such that groups
         #   may be replicated to follow the head shard.
+        # - NOTE: currently for the world size DOES NOT divide groups
+        #   case, we only support the case when n_groups == 1
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
+        assert num_heads % self.tp_size == 0, \
+            "Tensor parallel world size must divide num heads."
+
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
+            (
+                "If tensor parallel world size does not divide num_heads, "
+                "then num_groups must equal 1."
+            )
 
         self.ssm_state_size = ssm_state_size
         self.activation = activation
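The new assertions narrow the supported configurations: TP must divide the heads, and if it does not divide the groups there must be exactly one group. A minimal standalone check with the same logic (illustrative values, not from a real model config):

def check_tp_config(num_heads: int, n_groups: int, tp_size: int) -> None:
    assert num_heads % tp_size == 0, \
        "Tensor parallel world size must divide num heads."
    assert n_groups % tp_size == 0 or n_groups == 1, \
        "If tensor parallel world size does not divide n_groups, " \
        "then n_groups must equal 1."


check_tp_config(num_heads=128, n_groups=8, tp_size=4)  # groups shard evenly
check_tp_config(num_heads=128, n_groups=1, tp_size=4)  # single group, replicated
# check_tp_config(num_heads=128, n_groups=3, tp_size=4) would raise AssertionError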
@@ -284,11 +299,10 @@ class MambaMixer2(CustomOp):
             self.n_groups * self.ssm_state_size,  # expected model size
             (self.n_groups - n_groups) *
             self.ssm_state_size,  # extra dims assigned
-            self.num_heads //
-            n_groups,  # ratio for mapping back to original group
+            n_groups == 1,  # if there was only one group
         )
-        intermediate_settings = (intermediate_size, 0, 1)
-        head_setings = (self.num_heads, 0, 1)
+        intermediate_settings = (intermediate_size, 0, False)
+        head_setings = (self.num_heads, 0, False)
 
         # - the weight already has a "weight_loader" attribute
         # which set_weight_attrs will raise if we do not
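With the last hunk, each shard spec becomes (full_dim, extra_dims, duplicate_groups), where the third field is now a flag rather than a ratio. A rough worked example of the group spec for a single-group model under 4-way TP (all sizes below are made up for illustration):

ssm_state_size = 128
intermediate_size = 4096
num_heads = 64
n_groups = 1                                        # groups in the checkpoint
tp_size = 4
n_groups_padded = n_groups + (tp_size - n_groups)   # 1 group + 3 replicas = 4

group_shard_settings = (
    n_groups_padded * ssm_state_size,                # expected model size: 512
    (n_groups_padded - n_groups) * ssm_state_size,   # extra dims assigned: 384
    n_groups == 1,                                   # duplicate_groups flag: True
)
intermediate_settings = (intermediate_size, 0, False)
head_settings = (num_heads, 0, False)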