# SPDX-License-Identifier: Apache-2.0

import unittest.mock
from typing import Tuple

import pytest
import torch

from tests.utils import multi_gpu_test
from vllm.distributed.parallel_state import (init_distributed_environment,
                                             initialize_model_parallel)
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables


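# The test below exercises the gated RMSNorm of the Mamba2 mixer under tensor
# parallelism (TP=2): each rank normalizes its shard of the hidden dimension,
# and the result is checked against a single-GPU reference constructed with
# TP mock-patched away.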
@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
    "hidden_size_n_groups",
    [
        (64, 1),
        (64, 2),
        (64, 4),  # hidden_size must be divisible by num_gpus
        (100, 5),  # and n_groups must divide hidden_size
    ])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
    batch_size: int,
    seq_len: int,
    hidden_size_n_groups: Tuple[int, int],
    dtype: torch.dtype,
    device: str = 'cuda',
):
    hidden_size, n_groups = hidden_size_n_groups
    num_processes = 2

    def run_torch_spawn(fn, nprocs):
        # need to use torch.multiprocessing.spawn; otherwise torch.distributed
        # and CUDA do not play well with fork-based multiprocessing
        torch.multiprocessing.spawn(fn,
                                    args=(
                                        num_processes,
                                        batch_size,
                                        seq_len,
                                        hidden_size,
                                        n_groups,
                                        dtype,
                                        device,
                                    ),
                                    nprocs=nprocs)

    run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2)


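# worker entry point: torch.multiprocessing.spawn passes the process index as
# the first positional argument, which doubles as the local rank here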
def mixer2_gated_norm_tensor_parallel(
    local_rank: int,
    world_size: int,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    n_groups: int,
    dtype: torch.dtype,
    device: str,
):
    current_platform.seed_everything(0)

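    # bind this process to its own GPU and make it the default for newly
    # created tensors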
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

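    # rendezvous configuration consumed by torch.distributed when the
    # process group is initialized below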
    update_environment_variables({
        'RANK': str(local_rank),
        'LOCAL_RANK': str(local_rank),
        'WORLD_SIZE': str(world_size),
        'MASTER_ADDR': 'localhost',
        'MASTER_PORT': '12345',
    })

    # initialize distributed
    init_distributed_environment()
    initialize_model_parallel(tensor_model_parallel_size=world_size)

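    # after initialize_model_parallel, get_tensor_model_parallel_world_size()
    # inside Mixer2RMSNormGated resolves to world_size, so the module below
    # is constructed in its tensor-parallel configuration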
    # create random weights and inputs
    weight = torch.rand((hidden_size, ), dtype=dtype, device=device)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    gate_states = torch.randn(batch_size, seq_len, hidden_size)

    # create gated-norm with TP
    mixer = Mixer2RMSNormGated(
        full_hidden_size=hidden_size,
        full_n_groups=n_groups,
    )
    # load the full weight; the weight_loader registered on the parameter is
    # responsible for picking out this rank's shard
    mixer.weight.weight_loader(mixer.weight, weight)

    # create gated-norm without TP to compute reference
    # - use mock patching to disable TP while constructing the reference,
    #   so it sees world_size=1 / rank=0 and keeps the full hidden size
    with (unittest.mock.patch(
            "vllm.model_executor.layers.mamba.mamba_mixer2."
            "get_tensor_model_parallel_world_size",
            return_value=1),
          unittest.mock.patch(
              "vllm.model_executor.layers.mamba.mamba_mixer2."
              "get_tensor_model_parallel_rank",
              return_value=0)):
        mixer_single_gpu = Mixer2RMSNormGated(
            full_hidden_size=hidden_size,
            full_n_groups=n_groups,
        )
    # assign the full weight to the single-gpu mixer
    mixer_single_gpu.weight.data = weight

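    # each rank owns a contiguous (hidden_size // world_size)-wide slice of
    # the hidden dimension, so its output should match the same slice of the
    # full-size reference output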
    # generate and compare
    N = hidden_size // world_size
    output = mixer(
        hidden_states[..., local_rank * N:(local_rank + 1) * N],
        gate_states[..., local_rank * N:(local_rank + 1) * N],
    )
    ref_output = mixer_single_gpu(hidden_states, gate_states)
    assert torch.allclose(output,
                          ref_output[..., local_rank * N:(local_rank + 1) * N],
                          atol=1e-3,
                          rtol=1e-3)
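
# NOTE: run with pytest on a host with at least two CUDA devices, e.g.
#   pytest -k test_mixer2_gated_norm_multi_gpu
# (the multi_gpu_test decorator is expected to skip the test when fewer
# GPUs are available)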