[Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527)

Follow-on to #4332 that enables FP8 checkpoint loading for Mixtral; supersedes #4436.

This PR enables the following checkpoint loading features for Mixtral:

- Supports loading FP8 checkpoints for Mixtral, such as the `nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8` test model
- Supports static or dynamic activation quantization with static weight quantization (all per-tensor)
- Supports a different scale for each expert weight
- Supports FP8 in the QKV layer

Notes:

- The expert gate/router always runs at half/full precision for now.
- If the separate q/k/v projections have different weight scales, the weights are re-quantized using `layer.weight_scale.max()` so a single fused GEMM can be used for performance (see the sketch below).
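
To make the last note concrete, here is a minimal sketch in plain PyTorch of re-quantizing separately scaled FP8 weights onto a single shared scale. This is illustration only, not the vLLM code path: the helper name is made up, and it assumes PyTorch >= 2.1 for `torch.float8_e4m3fn`.

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def requantize_to_shared_scale(fp8_weights, scales):
    """Re-express q/k/v weights, each quantized with its own per-tensor scale,
    against scales.max() so one fused GEMM can consume all three."""
    shared_scale = torch.stack(list(scales)).max()
    out = []
    for w, s in zip(fp8_weights, scales):
        dequant = w.to(torch.float32) * s              # undo the original scale
        q = (dequant / shared_scale).clamp(-FP8_MAX, FP8_MAX)
        out.append(q.to(torch.float8_e4m3fn))
    return out, shared_scale
```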
Commit 2a052011ca by Michael Goin, 2024-05-04 14:45:16 -04:00, committed via GitHub (parent 36fb68f947).
2 changed files with 120 additions and 51 deletions.

File 1 of 2 — the MoE kernel test (`test_mixtral_moe`), updated for the renamed expert weight parameters:

```diff
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
-        vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
+        vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data

     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
```

File 2 of 2 — the Mixtral model definition (`MixtralMoE`, `MixtralAttention`, `MixtralForCausalLM`):

```diff
@@ -78,6 +78,8 @@ class MixtralMoE(nn.Module):
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
```
```diff
@@ -86,55 +88,79 @@ class MixtralMoE(nn.Module):
         params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype

+        # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
                                      params_dtype=self.params_dtype,
                                      quant_config=None)

-        self.ws = nn.Parameter(
+        if self.use_fp8:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        dtype=self.params_dtype))
-        self.w2s = nn.Parameter(
+                        dtype=params_dtype))
+        self.w2_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        dtype=self.params_dtype))
+                        dtype=params_dtype))

-        set_weight_attrs(self.ws, {
+        set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
         })
-        set_weight_attrs(self.w2s, {
+        set_weight_attrs(self.w2_weight, {
             "weight_loader": self.weight_loader,
         })

-        # Scaling factors for FP8 weights
-        self.ws_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-        self.w2s_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-
-        # Scaling factors for FP8 activations
-        need_act_scales = (self.use_fp8
-                           and quant_config.activation_scheme == "static")
-        self.as_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-        self.a2s_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-
-        if need_act_scales:
-            set_weight_attrs(self.as_scale, {
-                "weight_loader": self.weight_loader,
-            })
-            set_weight_attrs(self.a2s_scale, {
-                "weight_loader": self.weight_loader,
-            })
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                    dtype=torch.float32),
+                                         requires_grad=False)
+
+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8.")
+                self.a13_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                             requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
```
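
For orientation, the FP8 branch above ends up with the following per-expert state (toy sizes, plain tensors instead of `nn.Parameter`; in the real module `intermediate_size` is already divided by the tensor-parallel size):

```python
import torch

E, H, I = 8, 64, 128  # num_total_experts, hidden_size, intermediate_size // tp

w13_weight = torch.empty(E, 2 * I, H, dtype=torch.float8_e4m3fn)  # fused w1/w3
w2_weight = torch.empty(E, H, I, dtype=torch.float8_e4m3fn)
w13_scale = torch.ones(E, dtype=torch.float32)   # one weight scale per expert
w2_scale = torch.ones(E, dtype=torch.float32)
a13_scale = torch.zeros(E, dtype=torch.float32)  # static act scales, per expert
a2_scale = torch.zeros(E, dtype=torch.float32)   # (collapsed to a max() later)
```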
```diff
@@ -149,20 +175,49 @@ class MixtralMoE(nn.Module):
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name:
-            param_data[:] = param_data[:].max(loaded_weight)
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight

     def process_weights_after_loading(self):
-        if self.use_fp8:
-            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
-            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
-                    self.ws.data[expert, :, :])
-                w2s[expert, :, :], self.w2s_scale[
-                    expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
-            self.ws = nn.Parameter(ws, requires_grad=False)
-            self.w2s = nn.Parameter(w2s, requires_grad=False)
+                w13_weight[expert, :, :], self.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        # Since state_dict has an act_scale per expert but our kernels
+        # are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None.")
+
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. ")
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                          requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                         requires_grad=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
```
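
The fp16-checkpoint branch above quantizes each expert with vLLM's `ops.scaled_fp8_quant` CUDA op; the sketch below only approximates its per-tensor behavior in plain PyTorch (assumed semantics: it returns the FP8 tensor and a scale such that `weight ≈ q * scale`), and shows the static act_scale cleanup alongside it:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max


def naive_scaled_fp8_quant(w: torch.Tensor):
    # Per-tensor quantization: pick the scale so the largest magnitude maps to
    # the fp8 max, then cast. (Approximation of ops.scaled_fp8_quant.)
    scale = w.abs().max().float() / FP8_MAX
    q = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale


# fp16 checkpoint: quantize every expert's weights after loading.
E = 4
w13 = torch.randn(E, 16, 8, dtype=torch.float16)
w13_fp8 = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
w13_scale = torch.ones(E, dtype=torch.float32)
for e in range(E):
    w13_fp8[e], w13_scale[e] = naive_scaled_fp8_quant(w13[e])

# fp8 + static checkpoint: the state_dict holds one act_scale per expert, but
# the fused MoE kernel takes a single activation scale, so keep the max.
a13_scale = torch.tensor([0.011, 0.012, 0.010, 0.012])
a13_shared = a13_scale.max()
```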
```diff
@@ -170,17 +225,17 @@ class MixtralMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
-                                        self.ws,
-                                        self.w2s,
+                                        self.w13_weight,
+                                        self.w2_weight,
                                         router_logits,
                                         self.top_k,
                                         renormalize=True,
                                         inplace=True,
                                         use_fp8=self.use_fp8,
-                                        w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale,
-                                        a1_scale=self.as_scale,
-                                        a2_scale=self.a2s_scale)
+                                        w1_scale=self.w13_scale,
+                                        w2_scale=self.w2_scale,
+                                        a1_scale=self.a13_scale,
+                                        a2_scale=self.a2_scale)

         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
```
```diff
@@ -222,7 +277,9 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window

-        if isinstance(quant_config, Fp8Config):
+        if isinstance(
+                quant_config,
+                Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
             print_warning_once(
                 "For Mixtral FP8 quantization, we currently do not quantize "
                 "the attention layers until their FP8 performance is improved."
```
```diff
@@ -461,16 +518,23 @@ class MixtralForCausalLM(nn.Module):
         ]

         expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
              f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
```
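
Concretely, the new first comprehension in `expert_params_mapping` produces tuples like these (standalone reproduction; `num_local_experts` is 8 for Mixtral-8x7B, and the weight and act_scale lists that follow use the same pattern):

```python
num_local_experts = 8

weight_scale_mapping = [
    ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
     f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
    for expert_id in range(num_local_experts)
    for weight_name in ["w1", "w2", "w3"]
]

print(weight_scale_mapping[:3])
# [('w13_scale', 'experts.0.w1.weight_scale', 0),
#  ('w2_scale', 'experts.0.w2.weight_scale', 0),
#  ('w13_scale', 'experts.0.w3.weight_scale', 0)]
```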
```diff
@@ -512,3 +576,8 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
```
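
Finally, a quick usage sketch of the new `all_close_1d` helper (definition repeated here so the snippet runs standalone); it backs the warning emitted when per-expert act_scales disagree:

```python
import torch


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


assert all_close_1d(torch.tensor([0.01, 0.01, 0.01]))
assert not all_close_1d(torch.tensor([0.01, 0.02, 0.01]))
# When per-expert act_scales differ, the layer warns and uses their max().
```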