diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 5fd12649..a6a95c8d 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -133,7 +133,8 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
     if ngroups % tp_size == 0:
         return 0
 
-    return tp_size - ngroups % tp_size
+    # for n_groups == 1, this is exactly tp_size - n_groups
+    return tp_size - ngroups
 
 
 def mamba_v2_sharded_weight_loader(
@@ -153,7 +154,7 @@ def mamba_v2_sharded_weight_loader(
         boundary, loaded_boundary = 0, 0
 
         # - iterate over the shard specs
-        for full_dim, extra, ratio in shard_spec:
+        for full_dim, extra, duplicate_groups in shard_spec:
             # - full dim is the model dim (before TP).
             # - extra > 0, means there is expected overall increase
             #   of dimensions. This is so because of replication.
@@ -167,7 +168,12 @@ def mamba_v2_sharded_weight_loader(
             # - compute the rank into the loaded shard.
             # - if there is replication, different TP shards will
             #   take from the same rank.
-            rank = tp_rank // ratio
+            if duplicate_groups:
+                # NOTE: currently we only support duplication
+                # in the case where num_groups == 1
+                rank = 0
+            else:
+                rank = tp_rank
 
             # - leftmost boundary index into loaded weight.
             loaded_skip = rank * shard_size
@@ -233,12 +239,21 @@ class MambaMixer2(CustomOp):
 
         # - HOWEVER IF, world_size DOES NOT divide groups, then we need
         #   to allocate extra space in the shard, such that groups
         #   may be replicated to follow the head shard.
+        # - NOTE: currently for the world size DOES NOT divide groups
+        #   case, we only support the case when n_groups == 1
         self.tp_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
 
         assert num_heads % self.tp_size == 0, \
             "Tensor parallel world size must divide num heads."
+
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, \
+            (
+                "If tensor parallel world size does not divide num_groups, "
+                "then num_groups must equal 1."
+            )
+
         self.ssm_state_size = ssm_state_size
         self.activation = activation
 
@@ -284,11 +299,10 @@ class MambaMixer2(CustomOp):
             self.n_groups * self.ssm_state_size,  # expected model size
             (self.n_groups - n_groups) *
             self.ssm_state_size,  # extra dims assigned
-            self.num_heads //
-            n_groups,  # ratio for mapping back to original group
+            n_groups == 1,  # if there was only one group
         )
-        intermediate_settings = (intermediate_size, 0, 1)
-        head_setings = (self.num_heads, 0, 1)
+        intermediate_settings = (intermediate_size, 0, False)
+        head_setings = (self.num_heads, 0, False)
 
         # - the weight already has a "weight_loader" attribute
         #   which set_weight_attrs will raise if we do not
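
For reference on the replication arithmetic: the sketch below is a standalone copy of `extra_groups_for_head_shards` as changed above, plus a small worked example (the sample sizes and the `padded_groups` variable are illustrative, not part of the patch). With `n_groups == 1` and a TP world size of 4, the group dimension is padded with 3 extra slots so each rank ends up holding exactly one copy of the single group.

```python
# Standalone sketch of the group padding introduced above; mirrors the
# patched extra_groups_for_head_shards, with illustrative sizes.

def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    """Extra group slots needed so the group dim can follow the head shards."""
    if ngroups % tp_size == 0:
        return 0  # groups already divide evenly across TP shards

    # Only reached (and only supported) when ngroups == 1: replicate the
    # single group once per shard, i.e. tp_size total, tp_size - 1 extra.
    return tp_size - ngroups


tp_size, n_groups = 4, 1                                  # illustrative values
extra = extra_groups_for_head_shards(n_groups, tp_size)   # -> 3
padded_groups = n_groups + extra                           # -> 4, one per shard
assert padded_groups == tp_size
print(f"{extra=} {padded_groups=}")
```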
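
The third field of each `shard_spec` entry also changes meaning, from an integer ratio to a `duplicate_groups` flag, and the loader derives its source rank from it. Below is a minimal sketch of just that rank selection (the `source_rank` helper, the loop, and the sizes are hypothetical; only the `if duplicate_groups` branch comes from the diff): with duplication every TP rank copies the same slice of the loaded weight, otherwise each rank takes its own.

```python
# Toy model of the rank selection added to mamba_v2_sharded_weight_loader.
# source_rank and the sizes below are illustrative, not vLLM API.

def source_rank(tp_rank: int, duplicate_groups: bool) -> int:
    if duplicate_groups:
        # Replication path: the checkpoint holds a single group, so every
        # TP shard reads the same (rank 0) slice of the loaded weight.
        return 0
    # Normal path: each TP shard reads its own slice.
    return tp_rank


shard_size = 16  # width of one shard of this dimension (illustrative)
for tp_rank in range(4):
    # Mirrors `loaded_skip = rank * shard_size` in the loader.
    loaded_skip = source_rank(tp_rank, duplicate_groups=True) * shard_size
    print(f"rank {tp_rank} reads loaded_weight[{loaded_skip}:{loaded_skip + shard_size}]")
```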