[Kernel] Optimize FP8 support for MoE kernel / Mixtral via static scales (#4343)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

commit 12628d3c78
parent 258a2c58d0
@@ -146,7 +146,12 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);

-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
+void dynamic_scaled_fp8_quant(
   torch::Tensor& out,
   torch::Tensor& input,
   torch::Tensor& scale);
@@ -73,7 +73,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
-  ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
+  ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
+  ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
   ops.def(
     "moe_align_block_size",
     &moe_align_block_size,
@@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel(

 } // namespace vllm

-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,      // [..., d]
+  torch::Tensor& input,    // [..., d]
+  torch::Tensor& scale)    // [1]
+{
+  int64_t num_tokens = input.numel() / input.size(-1);
+  int64_t num_elems = input.numel();
+  dim3 grid(num_tokens);
+  dim3 block(1024);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  VLLM_DISPATCH_FLOATING_TYPES(
+    input.scalar_type(),
+    "scaled_fp8_quant_kernel",
+    [&] {
+      vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        out.data_ptr<c10::Float8_e4m3fn>(),
+        input.data_ptr<scalar_t>(),
+        scale.data_ptr<float>(),
+        num_elems);
+      });
+}
+
+void dynamic_scaled_fp8_quant(
   torch::Tensor& out,      // [..., d]
   torch::Tensor& input,    // [..., d]
   torch::Tensor& scale)    // [1]
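Note: static_scaled_fp8_quant only launches the element-wise conversion kernel with the caller-provided scale, while dynamic_scaled_fp8_quant (the renamed original op) still derives the scale from the input before converting. A minimal PyTorch sketch of the assumed semantics (reference only, not the kernel; rounding and saturation details may differ):

    import torch
    from typing import Tuple

    def ref_static_scaled_fp8_quant(x: torch.Tensor,
                                    scale: torch.Tensor) -> torch.Tensor:
        # Assumed semantics: out = fp8(x / scale), saturating to the finite
        # range of float8_e4m3fn (+/-448).
        finfo = torch.finfo(torch.float8_e4m3fn)
        return (x.float() / scale.float()).clamp(finfo.min, finfo.max).to(
            torch.float8_e4m3fn)

    def ref_dynamic_scaled_fp8_quant(
            x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # The dynamic path additionally computes the scale from the tensor itself.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = x.abs().max().float() / finfo.max
        return ref_static_scaled_fp8_quant(x, scale), scale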
@@ -168,10 +168,16 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor,


 # fp8
-def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
-    vllm_ops.scaled_fp8_quant(output, input, scale)
+    if scale is None:
+        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+        vllm_ops.dynamic_scaled_fp8_quant(output, input, scale)
+    else:
+        vllm_ops.static_scaled_fp8_quant(output, input, scale)
     return output, scale


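Note: with the updated Python wrapper above, callers choose the path simply by passing (or omitting) a scale. A usage sketch with made-up tensors (the import path is an assumption):

    import torch
    from vllm import _custom_ops as ops  # assumed import path of the wrapper above

    x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

    # Dynamic: the scale is recomputed from x on every call.
    x_fp8_dyn, dyn_scale = ops.scaled_fp8_quant(x)

    # Static: reuse a pre-calibrated scale (e.g. loaded from the checkpoint),
    # so only the element-wise conversion runs.
    static_scale = torch.tensor([0.02], device="cuda", dtype=torch.float32)
    x_fp8_stat, stat_scale = ops.scaled_fp8_quant(x, static_scale)
    assert stat_scale is static_scale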
@@ -220,8 +220,9 @@ def moe_align_block_size(


 def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            B_scale: torch.Tensor, topk_weights: torch.Tensor,
-                            topk_ids: torch.Tensor,
+                            A_scale: Optional[torch.Tensor],
+                            B_scale: Optional[torch.Tensor],
+                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                             sorted_token_ids: torch.Tensor,
                             expert_ids: torch.Tensor,
                             num_tokens_post_padded: torch.Tensor,
|
|||||||
assert sorted_token_ids.stride(0) == 1
|
assert sorted_token_ids.stride(0) == 1
|
||||||
|
|
||||||
if not use_fp8:
|
if not use_fp8:
|
||||||
A_scale = None
|
assert A_scale is None
|
||||||
assert B_scale is None
|
assert B_scale is None
|
||||||
else:
|
else:
|
||||||
A, A_scale = ops.scaled_fp8_quant(A)
|
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
|
||||||
assert B_scale is not None
|
assert B_scale is not None
|
||||||
|
|
||||||
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
|
||||||
@@ -318,6 +319,8 @@ def fused_moe(
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -434,6 +437,7 @@ def fused_moe(
     invoke_fused_moe_kernel(hidden_states,
                             w1,
                             intermediate_cache1,
+                            a1_scale,
                             w1_scale,
                             topk_weights,
                             topk_ids,
@@ -451,6 +455,7 @@ def fused_moe(
     invoke_fused_moe_kernel(intermediate_cache2,
                             w2,
                             intermediate_cache3,
+                            a2_scale,
                             w2_scale,
                             topk_weights,
                             topk_ids,
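Note: a sketch of how the extended fused_moe entry point might be driven end to end. Only the FP8-related keyword arguments come from this diff; the shapes, scale values, and the leading positional/topk arguments are illustrative assumptions, and the fp8 weights would have to be properly pre-quantized:

    import torch
    from vllm.model_executor.layers.fused_moe import fused_moe  # assumed import path

    E, H, N = 8, 4096, 14336  # experts, hidden size, intermediate size (illustrative)
    hidden_states = torch.randn(4, H, device="cuda", dtype=torch.float16)
    w13 = torch.empty(E, 2 * N, H, device="cuda", dtype=torch.float8_e4m3fn)  # fused w1/w3
    w2 = torch.empty(E, H, N, device="cuda", dtype=torch.float8_e4m3fn)
    gating_output = torch.randn(4, E, device="cuda", dtype=torch.float32)

    out = fused_moe(hidden_states, w13, w2, gating_output,
                    topk=2, renormalize=True, inplace=True,
                    use_fp8=True,
                    w1_scale=torch.ones(E, device="cuda"),         # per-expert weight scales
                    w2_scale=torch.ones(E, device="cuda"),
                    a1_scale=torch.tensor([0.02], device="cuda"),  # NEW: static activation scales;
                    a2_scale=torch.tensor([0.01], device="cuda"))  # pass None to keep dynamic quantization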
@@ -14,6 +14,12 @@ from vllm.model_executor.utils import set_weight_attrs
 class Fp8Config(QuantizationConfig):
     """Config class for FP8."""

+    def __init__(
+        self,
+        activation_scheme: str = "dynamic",
+    ) -> None:
+        self.activation_scheme = activation_scheme
+
     @classmethod
     def get_name(cls) -> str:
         return "fp8"
@@ -35,7 +41,8 @@ class Fp8Config(QuantizationConfig):

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
-        return cls()
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        return cls(activation_scheme)

     def get_quant_method(
             self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
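Note: from_config now reads activation_scheme from the checkpoint's quantization config. A hedged sketch of such a config block (only the activation_scheme key is taken from this diff; the companion field and the import path are assumptions):

    from vllm.model_executor.layers.quantization.fp8 import Fp8Config  # assumed module path

    quant_config_dict = {
        "quant_method": "fp8",          # assumed companion field
        "activation_scheme": "static",  # "static": use pre-calibrated act scales; "dynamic": compute per call
    }

    fp8_config = Fp8Config.from_config(quant_config_dict)
    assert fp8_config.activation_scheme == "static"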
@@ -105,6 +105,13 @@ class MixtralMoE(nn.Module):
                         device="cuda",
                         dtype=self.params_dtype))

+        set_weight_attrs(self.ws, {
+            "weight_loader": self.weight_loader,
+        })
+        set_weight_attrs(self.w2s, {
+            "weight_loader": self.weight_loader,
+        })
+
         # Scaling factors for FP8 weights
         self.ws_scale = nn.Parameter(
             torch.ones(
@@ -115,12 +122,23 @@ class MixtralMoE(nn.Module):
                 self.num_total_experts, device="cuda", dtype=torch.float32),
             requires_grad=False) if self.use_fp8 else None

-        set_weight_attrs(self.ws, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2s, {
-            "weight_loader": self.weight_loader,
-        })
+        # Scaling factors for FP8 activations
+        need_act_scales = (self.use_fp8
+                           and quant_config.activation_scheme == "static")
+        self.as_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+        self.a2s_scale = nn.Parameter(
+            torch.zeros(1, device="cuda", dtype=torch.float32),
+            requires_grad=False) if need_act_scales else None
+
+        if need_act_scales:
+            set_weight_attrs(self.as_scale, {
+                "weight_loader": self.weight_loader,
+            })
+            set_weight_attrs(self.a2s_scale, {
+                "weight_loader": self.weight_loader,
+            })

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                       weight_name: str, expert_id: int):
@@ -135,6 +153,8 @@ class MixtralMoE(nn.Module):
                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name:
+            param_data[:] = param_data[:].max(loaded_weight)

     def process_weights_after_loading(self):
         if self.use_fp8:
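Note: a single as_scale is shared by the w1/w3 inputs of every expert (and a2s_scale by the w2 inputs), so the loader folds each per-expert act_scale into the parameter with an element-wise max; the shared scale then covers the worst case across experts. A small sketch of that reduction with hypothetical values:

    import torch

    # Hypothetical per-expert activation scales found in the checkpoint.
    checkpoint_act_scales = [torch.tensor([0.011]),
                             torch.tensor([0.019]),
                             torch.tensor([0.014])]

    as_scale = torch.zeros(1)  # the parameter starts at zero (see __init__ above)
    for loaded_weight in checkpoint_act_scales:
        # Same reduction as weight_loader: keep the largest scale seen so far.
        as_scale[:] = as_scale[:].max(loaded_weight)

    assert torch.equal(as_scale, torch.tensor([0.019]))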
@@ -162,7 +182,9 @@ class MixtralMoE(nn.Module):
                                         inplace=True,
                                         use_fp8=self.use_fp8,
                                         w1_scale=self.ws_scale,
-                                        w2_scale=self.w2s_scale)
+                                        w2_scale=self.w2s_scale,
+                                        a1_scale=self.as_scale,
+                                        a2_scale=self.a2s_scale)

         if self.tp_size > 1:
             final_hidden_states = tensor_model_parallel_all_reduce(
@@ -443,11 +465,19 @@ class MixtralForCausalLM(nn.Module):
         ]

         expert_params_mapping = [
+            # These are the weights for the experts
             # (param_name, weight_name, expert_id)
             ("ws" if weight_name in ["w1", "w3"] else "w2s",
              f"experts.{expert_id}.{weight_name}.weight", expert_id)
             for expert_id in range(self.config.num_local_experts)
             for weight_name in ["w1", "w2", "w3"]
+        ] + [
+            # These are the activation scales for the experts
+            # (param_name, weight_name, expert_id)
+            ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale",
+             f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
+            for expert_id in range(self.config.num_local_experts)
+            for weight_name in ["w1", "w2", "w3"]
         ]

         params_dict = dict(self.named_parameters())
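Note: the extended mapping simply appends a second family of (param_name, checkpoint_weight_name, expert_id) triples. For illustration, the entries the comprehension above would produce for a hypothetical model with num_local_experts = 2:

    expert_params_mapping = [
        # weights
        ("ws", "experts.0.w1.weight", 0), ("w2s", "experts.0.w2.weight", 0),
        ("ws", "experts.0.w3.weight", 0),
        ("ws", "experts.1.w1.weight", 1), ("w2s", "experts.1.w2.weight", 1),
        ("ws", "experts.1.w3.weight", 1),
        # new: activation scales, routed to the shared as_scale / a2s_scale parameters
        ("as_scale", "experts.0.w1.act_scale", 0), ("a2s_scale", "experts.0.w2.act_scale", 0),
        ("as_scale", "experts.0.w3.act_scale", 0),
        ("as_scale", "experts.1.w1.act_scale", 1), ("a2s_scale", "experts.1.w2.act_scale", 1),
        ("as_scale", "experts.1.w3.act_scale", 1),
    ]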