Fused MOE for Mixtral (#2542)

Co-authored-by: chen shen <scv119@gmail.com>
This commit is contained in:
Philipp Moritz 2024-01-29 22:43:37 -08:00 committed by GitHub
parent 5d60def02c
commit ab40644669
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 114 additions and 108 deletions

View File

@ -95,7 +95,7 @@ void moe_align_block_size(
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
assert(num_experts <= NUM_MAX_EXPERTS);
VLLM_DISPATCH_INTEGRAL_TYPES(
topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),

View File

@ -100,6 +100,13 @@ void gptq_shuffle(
torch::Tensor q_weight,
torch::Tensor q_perm);
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
using fptr_t = uint64_t;
@ -121,12 +128,3 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets);
#endif
void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
int block_size,
torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad
);

View File

@ -57,9 +57,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
"moe_align_block_size",
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");

View File

@ -23,8 +23,6 @@
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
@ -33,10 +31,11 @@ from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMLP(nn.Module):
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
linear_method=linear_method)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // tp_size
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
linear_method=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
dtype=self.params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype))
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
batch_size, sequence_length, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
routing_weights,
selected_experts,
inplace=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
batch_size, sequence_length, hidden_dim)
return final_hidden_states.view(batch_size, sequence_length,
hidden_size)
class MixtralAttention(nn.Module):
@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method)
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
fall_back_to_pt=False):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)