[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales (#4343)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
258a2c58d0
commit
12628d3c78
@ -146,7 +146,12 @@ void gptq_shuffle(
|
||||
torch::Tensor q_perm,
|
||||
int bit);
|
||||
|
||||
void scaled_fp8_quant(
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void dynamic_scaled_fp8_quant(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
|
@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
|
||||
ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
|
||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
|
||||
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
|
||||
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
|
||||
ops.def(
|
||||
"moe_align_block_size",
|
||||
&moe_align_block_size,
|
||||
|
@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void scaled_fp8_quant(
|
||||
void static_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(1024);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(),
|
||||
"scaled_fp8_quant_kernel",
|
||||
[&] {
|
||||
vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(),
|
||||
num_elems);
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
|
@ -168,10 +168,16 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,
|
||||
|
||||
|
||||
# fp8
|
||||
def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
|
||||
vllm_ops.scaled_fp8_quant(output, input, scale)
|
||||
if scale is None:
|
||||
scale = torch.zeros(1, device=input.device, dtype=torch.float32)
|
||||
vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
|
||||
else:
|
||||
vllm_ops.static_scaled_fp8_quant(output, input, scale)
|
||||
return output, scale
|
||||
|
||||
|
||||
|
@ -220,8 +220,9 @@ def moe_align_block_size(
|
||||
|
||||
|
||||
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
B_scale: torch.Tensor, topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
A_scale: Optional[torch.Tensor],
|
||||
B_scale: Optional[torch.Tensor],
|
||||
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
|
||||
sorted_token_ids: torch.Tensor,
|
||||
expert_ids: torch.Tensor,
|
||||
num_tokens_post_padded: torch.Tensor,
|
||||
@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
|
||||
if not use_fp8:
|
||||
A_scale = None
|
||||
assert A_scale is None
|
||||
assert B_scale is None
|
||||
else:
|
||||
A, A_scale = ops.scaled_fp8_quant(A)
|
||||
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||
assert B_scale is not None
|
||||
|
||||
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
||||
@ -318,6 +319,8 @@ def fused_moe(
|
||||
use_fp8: bool = False,
|
||||
w1_scale: Optional[torch.Tensor] = None,
|
||||
w2_scale: Optional[torch.Tensor] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@ -434,6 +437,7 @@ def fused_moe(
|
||||
invoke_fused_moe_kernel(hidden_states,
|
||||
w1,
|
||||
intermediate_cache1,
|
||||
a1_scale,
|
||||
w1_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
@ -451,6 +455,7 @@ def fused_moe(
|
||||
invoke_fused_moe_kernel(intermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
|
@ -14,6 +14,12 @@ from vllm.model_executor.utils import set_weight_attrs
|
||||
class Fp8Config(QuantizationConfig):
|
||||
"""Config class for FP8."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
activation_scheme: str = "dynamic",
|
||||
) -> None:
|
||||
self.activation_scheme = activation_scheme
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "fp8"
|
||||
@ -35,7 +41,8 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
|
||||
return cls()
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
return cls(activation_scheme)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
|
||||
|
@ -105,6 +105,13 @@ class MixtralMoE(nn.Module):
|
||||
device="cuda",
|
||||
dtype=self.params_dtype))
|
||||
|
||||
set_weight_attrs(self.ws, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2s, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
# Scaling factors for FP8 weights
|
||||
self.ws_scale = nn.Parameter(
|
||||
torch.ones(
|
||||
@ -115,12 +122,23 @@ class MixtralMoE(nn.Module):
|
||||
self.num_total_experts, device="cuda", dtype=torch.float32),
|
||||
requires_grad=False) if self.use_fp8 else None
|
||||
|
||||
set_weight_attrs(self.ws, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.w2s, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
# Scaling factors for FP8 activations
|
||||
need_act_scales = (self.use_fp8
|
||||
and quant_config.activation_scheme == "static")
|
||||
self.as_scale = nn.Parameter(
|
||||
torch.zeros(1, device="cuda", dtype=torch.float32),
|
||||
requires_grad=False) if need_act_scales else None
|
||||
self.a2s_scale = nn.Parameter(
|
||||
torch.zeros(1, device="cuda", dtype=torch.float32),
|
||||
requires_grad=False) if need_act_scales else None
|
||||
|
||||
if need_act_scales:
|
||||
set_weight_attrs(self.as_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
set_weight_attrs(self.a2s_scale, {
|
||||
"weight_loader": self.weight_loader,
|
||||
})
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
|
||||
weight_name: str, expert_id: int):
|
||||
@ -135,6 +153,8 @@ class MixtralMoE(nn.Module):
|
||||
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
|
||||
if weight_name.endswith("w2.weight"):
|
||||
param_data[expert_id, :, :] = loaded_weight[:, shard]
|
||||
if "act_scale" in weight_name:
|
||||
param_data[:] = param_data[:].max(loaded_weight)
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
if self.use_fp8:
|
||||
@ -162,7 +182,9 @@ class MixtralMoE(nn.Module):
|
||||
inplace=True,
|
||||
use_fp8=self.use_fp8,
|
||||
w1_scale=self.ws_scale,
|
||||
w2_scale=self.w2s_scale)
|
||||
w2_scale=self.w2s_scale,
|
||||
a1_scale=self.as_scale,
|
||||
a2_scale=self.a2s_scale)
|
||||
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
@ -443,11 +465,19 @@ class MixtralForCausalLM(nn.Module):
|
||||
]
|
||||
|
||||
expert_params_mapping = [
|
||||
# These are the weights for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("ws" if weight_name in ["w1", "w3"] else "w2s",
|
||||
f"experts.{expert_id}.{weight_name}.weight", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
] + [
|
||||
# These are the activation scales for the experts
|
||||
# (param_name, weight_name, expert_id)
|
||||
("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
|
||||
f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
|
||||
for expert_id in range(self.config.num_local_experts)
|
||||
for weight_name in ["w1", "w2", "w3"]
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
Loading…
x
Reference in New Issue
Block a user