Fused MOE for Mixtral (#2542)
Co-authored-by: chen shen <scv119@gmail.com>
This commit is contained in:
parent
5d60def02c
commit
ab40644669
@ -95,7 +95,7 @@ void moe_align_block_size(
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
assert(num_experts <= NUM_MAX_EXPERTS);
|
||||
VLLM_DISPATCH_INTEGRAL_TYPES(
|
||||
topk_ids.scalar_type(), "moe_alig_block_size_kernel", [&] {
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
|
16
csrc/ops.h
16
csrc/ops.h
@ -100,6 +100,13 @@ void gptq_shuffle(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm);
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
using fptr_t = uint64_t;
|
||||
@ -121,12 +128,3 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
|
||||
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
|
||||
const std::vector<std::vector<int64_t>> &offsets);
|
||||
#endif
|
||||
|
||||
void moe_align_block_size(
|
||||
torch::Tensor topk_ids,
|
||||
int num_experts,
|
||||
int block_size,
|
||||
torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
torch::Tensor num_tokens_post_pad
|
||||
);
|
||||
|
@ -57,9 +57,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
ops.def(
|
||||
"moe_align_block_size",
|
||||
&moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||
"moe_align_block_size",
|
||||
&moe_align_block_size,
|
||||
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
|
||||
|
||||
// Cache ops
|
||||
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
|
||||
|
@ -23,8 +23,6 @@
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@ -33,10 +31,11 @@ from transformers import MixtralConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
ReplicatedLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.sequence import SamplerOutput
|
||||
@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
class MixtralMoE(nn.Module):
|
||||
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
|
||||
across all ranks.
|
||||
|
||||
Each expert's weights are sharded across all ranks and a fused MoE
|
||||
kernel is used for the forward pass, and finally we reduce the outputs
|
||||
across ranks.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
top_k: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.ffn_dim = intermediate_size
|
||||
self.hidden_dim = hidden_size
|
||||
|
||||
self.w1 = ReplicatedLinear(self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.w2 = ReplicatedLinear(self.ffn_dim,
|
||||
self.hidden_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.w3 = ReplicatedLinear(self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
|
||||
# TODO: Use vllm's SiluAndMul
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
w1_out, _ = self.w1(hidden_states)
|
||||
w1_out = self.act_fn(w1_out)
|
||||
w3_out, _ = self.w3(hidden_states)
|
||||
current_hidden_states = w1_out * w3_out
|
||||
current_hidden_states, _ = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = config.num_local_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
if self.tp_size > self.num_total_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {self.num_total_experts}.")
|
||||
# Split experts equally between ranks
|
||||
self.expert_indicies = np.array_split(range(
|
||||
self.num_total_experts), self.tp_size)[self.rank].tolist()
|
||||
if not self.expert_indicies:
|
||||
raise ValueError(
|
||||
f"Rank {self.rank} has no experts assigned to it.")
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = num_experts
|
||||
self.top_k = top_k
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size // tp_size
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
MixtralMLP(self.num_total_experts,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
linear_method=linear_method)
|
||||
if idx in self.expert_indicies else None
|
||||
for idx in range(self.num_total_experts)
|
||||
])
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
self.gate = ReplicatedLinear(self.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
linear_method=None)
|
||||
|
||||
self.ws = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
2 * self.intermediate_size,
|
||||
self.hidden_size,
|
||||
device="cuda",
|
||||
dtype=self.params_dtype))
|
||||
self.w2s = nn.Parameter(
|
||||
torch.empty(self.num_total_experts,
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
device="cuda",
|
||||
dtype=self.params_dtype))
|
||||
|
||||
set_weight_attrs(self.ws, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2s, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str, expert_id: int):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
param_data = param.data
|
||||
shard_size = self.intermediate_size
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
if weight_name.endswith("w1.weight"):
|
||||
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w3.weight"):
|
||||
param_data[expert_id,
|
||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
batch_size, sequence_length, hidden_size = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, self.hidden_size)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
|
||||
dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
final_hidden_states = None
|
||||
for expert_idx in self.expert_indicies:
|
||||
expert_layer = self.experts[expert_idx]
|
||||
expert_mask = (selected_experts == expert_idx)
|
||||
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
|
||||
keepdim=True)
|
||||
final_hidden_states = fused_moe(hidden_states,
|
||||
self.ws,
|
||||
self.w2s,
|
||||
routing_weights,
|
||||
selected_experts,
|
||||
inplace=True)
|
||||
|
||||
current_hidden_states = expert_layer(hidden_states).mul_(
|
||||
expert_weights)
|
||||
if final_hidden_states is None:
|
||||
final_hidden_states = current_hidden_states
|
||||
else:
|
||||
final_hidden_states.add_(current_hidden_states)
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states).view(
|
||||
batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states.view(batch_size, sequence_length,
|
||||
hidden_size)
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
sliding_window=config.sliding_window,
|
||||
linear_method=linear_method)
|
||||
self.block_sparse_moe = MixtralMoE(config=config,
|
||||
linear_method=linear_method)
|
||||
self.block_sparse_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
expert_params_mapping = [
|
||||
# (param_name, weight_name, expert_id)
|
||||
("ws" if weight_name in ["w1", "w3"] else "w2s",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path,
|
||||
@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
|
||||
fall_back_to_pt=False):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip experts that are not assigned to this worker.
|
||||
if ("block_sparse_moe.experts." in name
|
||||
and name not in params_dict):
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
for param_name, weight_name, expert_id in expert_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
weight_name,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
Loading…
x
Reference in New Issue
Block a user