[Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527)

Follow on to #4332 to enable FP8 checkpoint loading for Mixtral and supersedes #4436.

This PR enables the following checkpoint loading features for Mixtral:

Supports loading fp8 checkpoints for Mixtral, such as this "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model
Supports static or dynamic activation quantization with static weight quantization (all per tensor)
Supports different scales for each expert weight
Supports Fp8 in QKV layer
Notes:

The Expert Gate/Router always runs at half / full precision for now.
If there are different weight scales between QKV layer (for separate QKV weights), they are re-quantized using layer.weight_scale.max() so we can have a single gemm for performance.
This commit is contained in:
Michael Goin 2024-05-04 14:45:16 -04:00 committed by GitHub
parent 36fb68f947
commit 2a052011ca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 120 additions and 51 deletions

View File

@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
for i in range(config.num_local_experts):
weights = (hf_moe.experts[i].w1.weight.data,
hf_moe.experts[i].w3.weight.data)
vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")

View File

@ -78,6 +78,8 @@ class MixtralMoE(nn.Module):
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // self.tp_size
self.quant_config = quant_config
# FIXME(pcmoritz): Make this more general to support different
# quantization schemes
self.use_fp8 = isinstance(quant_config, Fp8Config)
@ -86,55 +88,79 @@ class MixtralMoE(nn.Module):
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
quant_config=None)
self.ws = nn.Parameter(
if self.use_fp8:
params_dtype = torch.float8_e4m3fn
self.w13_weight = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
dtype=self.params_dtype))
self.w2s = nn.Parameter(
dtype=params_dtype))
self.w2_weight = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
dtype=self.params_dtype))
dtype=params_dtype))
set_weight_attrs(self.ws, {
set_weight_attrs(self.w13_weight, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
set_weight_attrs(self.w2_weight, {
"weight_loader": self.weight_loader,
})
# Scaling factors for FP8 weights
self.ws_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
self.w2s_scale = nn.Parameter(
torch.ones(self.num_total_experts, dtype=torch.float32),
requires_grad=False) if self.use_fp8 else None
# Used for fp8.
self.w13_scale = None
self.w2_scale = None
self.a13_scale = None
self.a2_scale = None
# Scaling factors for FP8 activations
need_act_scales = (self.use_fp8
and quant_config.activation_scheme == "static")
self.as_scale = nn.Parameter(
torch.zeros(1, dtype=torch.float32),
requires_grad=False) if need_act_scales else None
self.a2s_scale = nn.Parameter(
torch.zeros(1, dtype=torch.float32),
requires_grad=False) if need_act_scales else None
if self.use_fp8:
# WEIGHT_SCALE (for fp8)
self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
dtype=torch.float32),
requires_grad=False)
if need_act_scales:
set_weight_attrs(self.as_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2s_scale, {
"weight_loader": self.weight_loader,
})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if quant_config.is_checkpoint_fp8_serialized:
set_weight_attrs(self.w13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2_scale, {
"weight_loader": self.weight_loader,
})
# ACT_SCALE (for fp8)
if quant_config.activation_scheme == "static":
if not quant_config.is_checkpoint_fp8_serialized:
raise ValueError(
"Found static activation scheme for checkpoint that "
"was not serialized fp8.")
self.a13_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
self.a2_scale = nn.Parameter(torch.zeros(
self.num_total_experts, dtype=torch.float32),
requires_grad=False)
set_weight_attrs(self.a13_scale, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.a2_scale, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
@ -149,20 +175,49 @@ class MixtralMoE(nn.Module):
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
if "act_scale" in weight_name:
param_data[:] = param_data[:].max(loaded_weight)
if "act_scale" in weight_name or "weight_scale" in weight_name:
param_data[expert_id] = loaded_weight
def process_weights_after_loading(self):
if self.use_fp8:
ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
# Fp8 is the only case where we need to process after loading.
if not self.use_fp8:
return
# If checkpoint is fp16, quantize here.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(self.w13_weight.data,
dtype=torch.float8_e4m3fn)
w2_weight = torch.empty_like(self.w2_weight.data,
dtype=torch.float8_e4m3fn)
for expert in range(self.num_total_experts):
ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
self.ws.data[expert, :, :])
w2s[expert, :, :], self.w2s_scale[
expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
self.ws = nn.Parameter(ws, requires_grad=False)
self.w2s = nn.Parameter(w2s, requires_grad=False)
w13_weight[expert, :, :], self.w13_scale[
expert] = ops.scaled_fp8_quant(
self.w13_weight.data[expert, :, :])
w2_weight[expert, :, :], self.w2_scale[
expert] = ops.scaled_fp8_quant(
self.w2_weight.data[expert, :, :])
self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
# If checkpoint is fp8 + static, cleanup act_scales.
# Since state_dict has an act_scale per expert but our kernels
# are passed one act_scale shared across all experts.
elif self.quant_config.activation_scheme == "static":
if self.a13_scale is None or self.a2_scale is None:
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(self.a13_scale)
or not all_close_1d(self.a2_scale)):
print_warning_once(
"Found act_scales that are not equal for fp8 MoE layer. "
"Using the maximum across experts for each layer. ")
self.a13_scale = nn.Parameter(self.a13_scale.max(),
requires_grad=False)
self.a2_scale = nn.Parameter(self.a2_scale.max(),
requires_grad=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
@ -170,17 +225,17 @@ class MixtralMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
self.w13_weight,
self.w2_weight,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
use_fp8=self.use_fp8,
w1_scale=self.ws_scale,
w2_scale=self.w2s_scale,
a1_scale=self.as_scale,
a2_scale=self.a2s_scale)
w1_scale=self.w13_scale,
w2_scale=self.w2_scale,
a1_scale=self.a13_scale,
a2_scale=self.a2_scale)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
@ -222,7 +277,9 @@ class MixtralAttention(nn.Module):
self.rope_theta = rope_theta
self.sliding_window = sliding_window
if isinstance(quant_config, Fp8Config):
if isinstance(
quant_config,
Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
print_warning_once(
"For Mixtral FP8 quantization, we currently do not quantize "
"the attention layers until their FP8 performance is improved."
@ -461,16 +518,23 @@ class MixtralForCausalLM(nn.Module):
]
expert_params_mapping = [
# These are the weight scales for the experts
# (param_name, weight_name, expert_id)
("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the weights for the experts
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
@ -512,3 +576,8 @@ class MixtralForCausalLM(nn.Module):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))