[Kernel] Support MoE Fp8 Checkpoints for Mixtral (Static Weights with Dynamic/Static Activations) (#4527)
Follow-up to #4332 that enables FP8 checkpoint loading for Mixtral; supersedes #4436. This PR adds the following checkpoint-loading features for Mixtral:

- Loading FP8 checkpoints, such as the "nm-testing/Mixtral-8x7B-Instruct-v0.1-FP8" test model.
- Static or dynamic activation quantization combined with static weight quantization (all per tensor).
- A separate weight scale for each expert.
- FP8 in the fused QKV layer.

Notes:

- The expert gate/router always runs at half/full precision for now.
- If the separate Q, K, and V weights have different weight scales, they are re-quantized using layer.weight_scale.max() so that a single GEMM can be used for performance.
Parent commit: 36fb68f947
This commit: 2a052011ca
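To clarify the re-quantization note in the description: when a checkpoint stores separate per-tensor scales for the Q, K, and V projections, the fused QKV weight needs a single shared scale so the projection can run as one FP8 GEMM. Below is a minimal sketch of that idea, not the vLLM implementation; the tensor names, shapes, and scale values are illustrative.

```python
import torch

# Hypothetical per-tensor scales loaded from an FP8 checkpoint for separate
# Q/K/V shards (weight_fp8.float() * scale ~= original fp16 weight).
q_fp8, k_fp8, v_fp8 = (torch.randn(16, 32).to(torch.float8_e4m3fn)
                       for _ in range(3))
q_scale, k_scale, v_scale = 0.010, 0.012, 0.008

# Use the maximum scale for the fused tensor so one GEMM can cover Q, K, V.
max_scale = max(q_scale, k_scale, v_scale)

def requantize(w_fp8: torch.Tensor, old_scale: float,
               new_scale: float) -> torch.Tensor:
    # Dequantize with the shard's own scale, then re-quantize with the shared one.
    w = w_fp8.to(torch.float32) * old_scale
    return (w / new_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)

qkv_fp8 = torch.cat([
    requantize(q_fp8, q_scale, max_scale),
    requantize(k_fp8, k_scale, max_scale),
    requantize(v_fp8, v_scale, max_scale),
], dim=0)
# qkv_fp8 together with the single max_scale now stands in for the three shards.
```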
```diff
@@ -77,8 +77,8 @@ def test_mixtral_moe(dtype: torch.dtype):
     for i in range(config.num_local_experts):
         weights = (hf_moe.experts[i].w1.weight.data,
                    hf_moe.experts[i].w3.weight.data)
-        vllm_moe.ws[i][:] = torch.cat(weights, dim=0)
-        vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data
+        vllm_moe.w13_weight[i][:] = torch.cat(weights, dim=0)
+        vllm_moe.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
 
     # Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
     hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
```
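For context on the renamed parameters in this test: each expert's w1 (gate) and w3 (up) projections are stored concatenated along the output dimension in a single w13_weight tensor, so one GEMM per expert produces both halves. A rough sketch of that layout, with made-up sizes:

```python
import torch

hidden_size, intermediate_size, num_experts = 64, 128, 8  # illustrative sizes

# Fused layout: [num_experts, 2 * intermediate_size, hidden_size].
w13_weight = torch.empty(num_experts, 2 * intermediate_size, hidden_size)

# Per-expert HF weights (w1 = gate proj, w3 = up proj) copied into the fused tensor.
w1 = torch.randn(intermediate_size, hidden_size)
w3 = torch.randn(intermediate_size, hidden_size)
w13_weight[0] = torch.cat((w1, w3), dim=0)

x = torch.randn(4, hidden_size)                      # 4 tokens
gate_up = x @ w13_weight[0].t()                      # one GEMM for both projections
gate, up = gate_up.split(intermediate_size, dim=-1)  # then SiLU(gate) * up downstream
```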
```diff
@@ -78,6 +78,8 @@ class MixtralMoE(nn.Module):
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
         # FIXME(pcmoritz): Make this more general to support different
         # quantization schemes
         self.use_fp8 = isinstance(quant_config, Fp8Config)
```
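The rest of the change branches on two properties of the FP8 quantization config: whether the checkpoint was serialized in FP8, and whether activations use static or dynamic scales. A simplified sketch of that decision logic (the field names follow the diff; the dataclass itself is only illustrative, not vLLM's Fp8Config):

```python
from dataclasses import dataclass

@dataclass
class Fp8ConfigSketch:
    # True when the checkpoint already stores fp8 weights + scales;
    # False when we load fp16/bf16 and quantize after loading.
    is_checkpoint_fp8_serialized: bool
    # "static": act scales come from the checkpoint; "dynamic": computed at runtime.
    activation_scheme: str

def needs_loaded_scales(cfg: Fp8ConfigSketch) -> tuple[bool, bool]:
    """Returns (load_weight_scales, load_act_scales) for the MoE layer."""
    load_weight_scales = cfg.is_checkpoint_fp8_serialized
    load_act_scales = (cfg.is_checkpoint_fp8_serialized
                       and cfg.activation_scheme == "static")
    return load_weight_scales, load_act_scales

print(needs_loaded_scales(Fp8ConfigSketch(True, "static")))    # (True, True)
print(needs_loaded_scales(Fp8ConfigSketch(False, "dynamic")))  # (False, False)
```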
```diff
@@ -86,53 +88,77 @@ class MixtralMoE(nn.Module):
         params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
 
+        # Gate always runs at half / full precision for now.
         self.gate = ReplicatedLinear(self.hidden_size,
                                      self.num_total_experts,
                                      bias=False,
                                      params_dtype=self.params_dtype,
                                      quant_config=None)
 
-        self.ws = nn.Parameter(
+        if self.use_fp8:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         2 * self.intermediate_size,
                         self.hidden_size,
-                        dtype=self.params_dtype))
-        self.w2s = nn.Parameter(
+                        dtype=params_dtype))
+        self.w2_weight = nn.Parameter(
             torch.empty(self.num_total_experts,
                         self.hidden_size,
                         self.intermediate_size,
-                        dtype=self.params_dtype))
+                        dtype=params_dtype))
 
-        set_weight_attrs(self.ws, {
+        set_weight_attrs(self.w13_weight, {
             "weight_loader": self.weight_loader,
         })
-        set_weight_attrs(self.w2s, {
+        set_weight_attrs(self.w2_weight, {
             "weight_loader": self.weight_loader,
         })
 
-        # Scaling factors for FP8 weights
-        self.ws_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-        self.w2s_scale = nn.Parameter(
-            torch.ones(self.num_total_experts, dtype=torch.float32),
-            requires_grad=False) if self.use_fp8 else None
-
-        # Scaling factors for FP8 activations
-        need_act_scales = (self.use_fp8
-                           and quant_config.activation_scheme == "static")
-        self.as_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-        self.a2s_scale = nn.Parameter(
-            torch.zeros(1, dtype=torch.float32),
-            requires_grad=False) if need_act_scales else None
-
-        if need_act_scales:
-            set_weight_attrs(self.as_scale, {
-                "weight_loader": self.weight_loader,
-            })
-            set_weight_attrs(self.a2s_scale, {
-                "weight_loader": self.weight_loader,
-            })
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                     dtype=torch.float32),
+                                          requires_grad=False)
+            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
+                                                    dtype=torch.float32),
+                                         requires_grad=False)
+
+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            #   process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(self.w13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.w2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8.")
+                self.a13_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                              requires_grad=False)
+                self.a2_scale = nn.Parameter(torch.zeros(
+                    self.num_total_experts, dtype=torch.float32),
+                                             requires_grad=False)
+
+                set_weight_attrs(self.a13_scale, {
+                    "weight_loader": self.weight_loader,
+                })
+                set_weight_attrs(self.a2_scale, {
+                    "weight_loader": self.weight_loader,
+                })
 
```
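To make the parameter shapes above concrete: with FP8 enabled, the expert weights are stored directly in torch.float8_e4m3fn, with one float32 weight scale per expert and, under static quantization, one float32 activation scale per expert. A standalone sketch with scaled-down sizes (real Mixtral-8x7B uses hidden 4096 and intermediate 14336):

```python
import torch
import torch.nn as nn

num_experts, hidden, intermediate = 8, 1024, 3584  # scaled-down illustrative sizes

# Fused gate/up projection and down projection, stored in fp8.
w13_weight = nn.Parameter(torch.empty(num_experts, 2 * intermediate, hidden,
                                      dtype=torch.float8_e4m3fn),
                          requires_grad=False)
w2_weight = nn.Parameter(torch.empty(num_experts, hidden, intermediate,
                                     dtype=torch.float8_e4m3fn),
                         requires_grad=False)

# One per-tensor scale per expert (per-tensor quantization with a different
# scale per expert weight, as described in the PR summary).
w13_scale = nn.Parameter(torch.ones(num_experts, dtype=torch.float32),
                         requires_grad=False)
w2_scale = nn.Parameter(torch.ones(num_experts, dtype=torch.float32),
                        requires_grad=False)

# Static activation scheme: activation scales are also loaded per expert and
# later collapsed to a single value (see process_weights_after_loading below).
a13_scale = nn.Parameter(torch.zeros(num_experts, dtype=torch.float32),
                         requires_grad=False)
a2_scale = nn.Parameter(torch.zeros(num_experts, dtype=torch.float32),
                        requires_grad=False)
```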
```diff
@@ -149,20 +175,49 @@ class MixtralMoE(nn.Module):
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
-        if "act_scale" in weight_name:
-            param_data[:] = param_data[:].max(loaded_weight)
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
 
     def process_weights_after_loading(self):
-        if self.use_fp8:
-            ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
-            w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(self.w13_weight.data,
+                                          dtype=torch.float8_e4m3fn)
+            w2_weight = torch.empty_like(self.w2_weight.data,
+                                         dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant(
-                    self.ws.data[expert, :, :])
-                w2s[expert, :, :], self.w2s_scale[
-                    expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :])
-            self.ws = nn.Parameter(ws, requires_grad=False)
-            self.w2s = nn.Parameter(w2s, requires_grad=False)
+                w13_weight[expert, :, :], self.w13_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w13_weight.data[expert, :, :])
+                w2_weight[expert, :, :], self.w2_scale[
+                    expert] = ops.scaled_fp8_quant(
+                        self.w2_weight.data[expert, :, :])
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        # Since state_dict has an act_scale per expert but our kernels
+        # are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None.")
+
+            if (not all_close_1d(self.a13_scale)
+                    or not all_close_1d(self.a2_scale)):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. ")
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(),
+                                          requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(),
+                                         requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
```
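The fp16-checkpoint path above relies on ops.scaled_fp8_quant to produce an FP8 tensor plus a per-tensor scale for each expert. A rough pure-PyTorch equivalent of that per-tensor quantization, shown only as a sketch of the math rather than vLLM's CUDA kernel:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def scaled_fp8_quant_ref(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-tensor dynamic quantization: weight ~= weight_fp8.float() * scale."""
    scale = weight.abs().max().float() / FP8_E4M3_MAX
    weight_fp8 = (weight.float() / scale).clamp(
        -FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return weight_fp8, scale

# One scale per expert, mirroring the loop in process_weights_after_loading().
w13 = torch.randn(8, 256, 128)  # [num_experts, 2 * intermediate, hidden], fp32 here
w13_fp8 = torch.empty_like(w13, dtype=torch.float8_e4m3fn)
w13_scale = torch.empty(8, dtype=torch.float32)
for expert in range(w13.shape[0]):
    w13_fp8[expert], w13_scale[expert] = scaled_fp8_quant_ref(w13[expert])
```

The elif branch handles the other case, an FP8-serialized checkpoint with static activation scales: the per-expert act_scales are collapsed to their maximum because the fused kernel takes a single activation scale per layer.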
```diff
@@ -170,17 +225,17 @@ class MixtralMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
         final_hidden_states = fused_moe(hidden_states,
-                                        self.ws,
-                                        self.w2s,
+                                        self.w13_weight,
+                                        self.w2_weight,
                                         router_logits,
                                         self.top_k,
                                         renormalize=True,
                                         inplace=True,
                                         use_fp8=self.use_fp8,
-                                        w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale,
-                                        a1_scale=self.as_scale,
-                                        a2_scale=self.a2s_scale)
+                                        w1_scale=self.w13_scale,
+                                        w2_scale=self.w2_scale,
+                                        a1_scale=self.a13_scale,
+                                        a2_scale=self.a2_scale)
 
         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
```
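For intuition on why passing per-tensor scales into fused_moe is sufficient: with per-tensor quantization, x ≈ x_q · a_scale and W ≈ W_q · w_scale, so x · Wᵀ ≈ (x_q · W_qᵀ) · a_scale · w_scale, meaning the GEMM can run on the quantized operands and be rescaled once at the end. A small numeric check of that identity in plain PyTorch (not the fused Triton kernel):

```python
import torch

FP8_MAX = 448.0

def quant(t: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-tensor dynamic quantization to float8_e4m3fn.
    scale = t.abs().max() / FP8_MAX
    return (t / scale).to(torch.float8_e4m3fn), scale

x = torch.randn(4, 64)   # activations for 4 tokens
w = torch.randn(32, 64)  # one expert's weight

x_q, a_scale = quant(x)
w_q, w_scale = quant(w)

ref = x @ w.t()
# Matmul on the unscaled quantized values, then apply both per-tensor scales once.
approx = (x_q.float() @ w_q.float().t()) * a_scale * w_scale

rel_err = (ref - approx).abs().max() / ref.abs().max()
print(f"max relative error: {rel_err:.3f}")  # on the order of fp8 rounding error
```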
```diff
@@ -222,7 +277,9 @@ class MixtralAttention(nn.Module):
         self.rope_theta = rope_theta
         self.sliding_window = sliding_window
 
-        if isinstance(quant_config, Fp8Config):
+        if isinstance(
+                quant_config,
+                Fp8Config) and not quant_config.is_checkpoint_fp8_serialized:
             print_warning_once(
                 "For Mixtral FP8 quantization, we currently do not quantize "
                 "the attention layers until their FP8 performance is improved."
```
```diff
@@ -461,16 +518,23 @@ class MixtralForCausalLM(nn.Module):
         ]
 
         expert_params_mapping = [
+            # These are the weight scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+             f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ] + [
             # These are the weights for the experts
             # (param_name, weight_name, expert_id)
-            ("ws" if weight_name in ["w1", "w3"] else "w2s",
+            ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
         ] + [
             # These are the activation scales for the experts
             # (param_name, weight_name, expert_id)
-            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+            ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
              f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
```
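For readers unfamiliar with this pattern: each (param_name, weight_name, expert_id) triple tells the weight loader how to route a checkpoint tensor (for example experts.3.w1.weight_scale) into the corresponding fused per-expert parameter (w13_scale, row 3). A hypothetical, simplified sketch of the consuming loop; the real load_weights in mixtral.py also handles stacked QKV/gate-up parameters and its loader call differs in detail:

```python
# `weights` yields (checkpoint_name, tensor) pairs; `params_dict` maps
# parameter names to the model's parameters (both assumed for this sketch).
def load_expert_weights(weights, params_dict, expert_params_mapping):
    for ckpt_name, loaded_weight in weights:
        for param_name, weight_name, expert_id in expert_params_mapping:
            if weight_name not in ckpt_name:
                continue
            # e.g. "...experts.3.w1.weight_scale" -> "...w13_scale"
            target_name = ckpt_name.replace(weight_name, param_name)
            param = params_dict[target_name]
            # weight_loader was attached via set_weight_attrs in MixtralMoE and
            # writes into the fused tensor at index expert_id.
            weight_loader = getattr(param, "weight_loader")
            weight_loader(param, loaded_weight, weight_name, expert_id=expert_id)
            break
```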
```diff
@@ -512,3 +576,8 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
```