From debd6bbf0951222c9ad7a1a91b958027f2ed0782 Mon Sep 17 00:00:00 2001 From: Pavani Majety Date: Tue, 11 Mar 2025 22:13:11 -0700 Subject: [PATCH] [Kernel] Add ModelOpt FP4 Checkpoint Support (#12520) Signed-off-by: Pavani Majety --- csrc/ops.h | 8 +- .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 6 + .../fp4/nvfp4_scaled_mm_kernels.cu | 7 +- csrc/torch_bindings.cpp | 4 + .../decoder_only/language/test_nvfp4.py | 82 ++++++ vllm/_custom_ops.py | 4 + vllm/config.py | 2 +- vllm/model_executor/layers/linear.py | 23 +- .../layers/quantization/__init__.py | 4 +- .../layers/quantization/modelopt.py | 278 +++++++++++++++++- 10 files changed, 388 insertions(+), 30 deletions(-) create mode 100644 tests/models/decoder_only/language/test_nvfp4.py diff --git a/csrc/ops.h b/csrc/ops.h index 724d7c92..7434aead 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -160,14 +160,16 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, int64_t ggml_moe_get_block_size(int64_t type); #ifndef USE_ROCM + +bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability); +bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); +bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); + void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, torch::Tensor const& B, torch::Tensor const& A_sf, torch::Tensor const& B_sf, torch::Tensor const& alpha); -bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); -bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); - void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales, diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index 7b57b32f..61b75e92 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A, "be compiled using CUDA 12.8 and target " "compute capability 100 or above."); } + +bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { + int runtimeVersion; + cudaRuntimeGetVersion(&runtimeVersion); + return cuda_device_capability >= 100 && runtimeVersion >= 12080; +} \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu index 9b30e4fe..6e14de0c 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu @@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B, #endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) #define CHECK_TYPE(x, st, m) \ - TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m) -#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") + TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) \ + TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor") #define CHECK_CONTIGUOUS(x, m) \ - TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") + TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous") #define CHECK_INPUT(x, st, m) \ CHECK_TH_CUDA(x, m); \ CHECK_CONTIGUOUS(x, m); \ diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index eac27e64..d3bcb86a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -434,6 +434,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " 
Tensor! output_scale, Tensor input_scale) -> ()"); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices + // of the given capability + ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); + ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4); #endif // Quantized GEMM for GPTQ. diff --git a/tests/models/decoder_only/language/test_nvfp4.py b/tests/models/decoder_only/language/test_nvfp4.py new file mode 100644 index 00000000..442e8e93 --- /dev/null +++ b/tests/models/decoder_only/language/test_nvfp4.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: Apache-2.0 +# flake8: noqa +"""Tests Model Optimizer nvfp4 models against ground truth generation +Note: these tests will only pass on B200 +""" +import os +from typing import List + +import pytest +from transformers import AutoTokenizer + +from tests.quantization.utils import is_quant_method_supported +from vllm import LLM, SamplingParams + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +MODELS = ["nvidia/Llama-3.3-70B-Instruct-FP4"] + +EXPECTED_STRS_MAP = { + "nvidia/Llama-3.3-70B-Instruct-FP4": [ + 'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference', + 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', + 'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process', + 'A neural network is a type of machine learning model inspired by the structure and function of the human brain', + 'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push', + 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading', + 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', + 'Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts' + ] +} + + +# This test compares against golden strings for exact match since +# there is no baseline implementation to compare against +# and is unstable w.r.t specifics of the fp4 implementation or +# the hardware being run on. +# Disabled to prevent it from breaking the build +@pytest.mark.skip( + reason= + "Prevent unstable test based on golden strings from breaking the build " + " and test input model being too large and hanging the system.") +@pytest.mark.quant_model +@pytest.mark.skipif(not is_quant_method_supported("nvfp4"), + reason="nvfp4 is not supported on this GPU type.") +@pytest.mark.parametrize("model_name", MODELS) +def test_models(example_prompts, model_name) -> None: + model = LLM( + model=model_name, + max_model_len=MAX_MODEL_LEN, + trust_remote_code=True, + enforce_eager=True, + quantization="nvfp4", + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name) + formatted_prompts = [ + tokenizer.apply_chat_template([{ + "role": "user", + "content": prompt + }], + tokenize=False, + add_generation_prompt=True) + for prompt in example_prompts + ] + params = SamplingParams(max_tokens=20, temperature=0) + generations: List[str] = [] + # Note: these need to be run 1 at a time due to numerical precision, + # since the expected strs were generated this way. 
+ for prompt in formatted_prompts: + outputs = model.generate(prompt, params) + generations.append(outputs[0].outputs[0].text) + del model + + print(model_name, generations) + expected_strs = EXPECTED_STRS_MAP[model_name] + for i in range(len(example_prompts)): + generated_str = generations[i] + expected_str = expected_strs[i] + assert expected_str == generated_str, ( + f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9f5b4871..64175cc4 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -467,6 +467,10 @@ if hasattr(torch.ops._C, "ggml_dequantize"): # cutlass +def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool: + return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability) + + def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor, block_scale_a: torch.Tensor, block_scale_b: torch.Tensor, alpha: torch.Tensor, diff --git a/vllm/config.py b/vllm/config.py index 26c02563..a0f30d0e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -613,7 +613,7 @@ class ModelConfig: optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", - "compressed-tensors", "experts_int8", "quark" + "compressed-tensors", "experts_int8", "quark", "nvfp4" ] if self.quantization is not None: self.quantization = self.quantization.lower() diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3912c53e..1ae57407 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -30,12 +30,23 @@ from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) WEIGHT_LOADER_V2_SUPPORTED = [ - "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", - "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", - "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", - "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod", - "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod", - "HQQMarlinMethod", "QuarkLinearMethod" + "CompressedTensorsLinearMethod", + "AWQMarlinLinearMethod", + "AWQLinearMethod", + "GPTQMarlinLinearMethod", + "Fp8LinearMethod", + "MarlinLinearMethod", + "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", + "IPEXGPTQLinearMethod", + "HQQMarlinMethod", + "QuarkLinearMethod", + "ModelOptNvFp4LinearMethod", ] diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 6cd508d0..a4dc4e9c 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -14,6 +14,7 @@ QUANTIZATION_METHODS: List[str] = [ "ptpc_fp8", "fbgemm_fp8", "modelopt", + "nvfp4", # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) 
"marlin", @@ -97,7 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .hqq_marlin import HQQMarlinConfig from .ipex_quant import IPEXConfig from .marlin import MarlinConfig - from .modelopt import ModelOptFp8Config + from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config from .moe_wna16 import MoeWNA16Config from .neuron_quant import NeuronQuantConfig from .ptpc_fp8 import PTPCFp8Config @@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, "modelopt": ModelOptFp8Config, + "nvfp4": ModelOptNvFp4Config, # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) "marlin": MarlinConfig, diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 1f8af8d6..3de15369 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1,24 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import torch from torch.nn import Module from torch.nn.parameter import Parameter +from vllm._custom_ops import (cutlass_scaled_fp4_mm, + cutlass_scaled_mm_supports_fp4, scaled_fp4_quant) from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, requantize_with_max_scale) from vllm.model_executor.parameter import (ModelWeightParameter, PerTensorScaleParameter) +from vllm.platforms import current_platform logger = init_logger(__name__) -ACTIVATION_SCHEMES = ["static"] +QUANT_ALGOS = ["FP8", "NVFP4"] +KV_CACHE_QUANT_ALGOS = ["FP8"] class ModelOptFp8Config(QuantizationConfig): @@ -54,12 +61,13 @@ class ModelOptFp8Config(QuantizationConfig): def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": quant_config = cls.get_from_keys(config, ["quantization"]) quant_method = quant_config["quant_algo"] - is_checkpoint_fp8_serialized = ("FP8" in quant_method) - if not is_checkpoint_fp8_serialized: - raise ValueError("ModelOpt currently only supports static FP8 " - "quantization in vLLM. Please check the " + if quant_method not in QUANT_ALGOS: + raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" + " quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " "quant configuration.") + is_checkpoint_fp8_serialized = ("FP8" in quant_method) + return cls(is_checkpoint_fp8_serialized) def get_quant_method(self, layer: torch.nn.Module, @@ -72,15 +80,6 @@ class ModelOptFp8Config(QuantizationConfig): return None -class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): - """ - Supports loading kv-cache scaling factors from FP8 checkpoints. - """ - - def __init__(self, quant_config: ModelOptFp8Config): - super().__init__(quant_config) - - class ModelOptFp8LinearMethod(LinearMethodBase): """Linear method for Model Optimizer static quantization. 
Supports loading FP8 checkpoints with static weight scale and @@ -162,3 +161,250 @@ class ModelOptFp8LinearMethod(LinearMethodBase): weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias) + + +class ModelOptNvFp4Config(QuantizationConfig): + """Config class for ModelOpt FP4.""" + + def __init__( + self, + is_checkpoint_nvfp4_serialized: bool, + kv_cache_quant_algo: str, + exclude_modules: List[str], + group_size: int = 16, + ) -> None: + self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized + if is_checkpoint_nvfp4_serialized: + logger.warning( + "Detected ModelOpt NVFP4 checkpoint. Please note that" + " the format is experimental and could change in future.") + + self.group_size = group_size + self.kv_cache_quant_algo = kv_cache_quant_algo + self.exclude_modules = exclude_modules + + @classmethod + def get_name(cls) -> str: + return "modelopt_nvfp4" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half, torch.float8_e4m3fn] + + @classmethod + def get_min_capability(cls) -> int: + return 100 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config": + quant_config = cls.get_from_keys(config, ["quantization"]) + quant_method = quant_config["quant_algo"] + if quant_method not in QUANT_ALGOS: + raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}" + " quantizations in vLLM. Please check the " + "`hf_quant_config.json` file for your model's " + "quant configuration.") + is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method) + kv_cache_quant_algo = quant_config["kv_cache_quant_algo"] + group_size = quant_config["group_size"] + exclude_modules = quant_config["exclude_modules"] + if not (group_size and kv_cache_quant_algo and exclude_modules): + raise ValueError("NVFP4 quantization requires group size and " + "kv_cache_quant_algo specified in " + "hf_quant_config.json") + return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo, + exclude_modules, group_size) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.exclude_modules): + return UnquantizedLinearMethod() + return ModelOptNvFp4LinearMethod(self) + elif isinstance(layer, Attention): + return ModelOptFp8KVCacheMethod(self) + return None + + +def cutlass_fp4_supported() -> bool: + if not current_platform.is_cuda(): + return False + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + return cutlass_scaled_mm_supports_fp4(capability) + + +class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from FP8 checkpoints. + """ + + def __init__(self, quant_config: Union[ModelOptFp8Config, + ModelOptNvFp4Config]): + super().__init__(quant_config) + + +class ModelOptNvFp4LinearMethod(LinearMethodBase): + """Linear method for Model Optimizer NVFP4. + Supports loading NVFP4 checkpoints with the following structure: + + input_scale: torch.float32, scalar , + weight: NVFP4(represented as byte) Shape: [1, X, y/2] + weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, + weight_scale_2: torch.float32, scalar, + Args: quant_config: The ModelOpt quantization config. 
+    """
+
+    def __init__(self, quant_config: ModelOptNvFp4Config):
+        self.quant_config = quant_config
+        self.cutlass_nvfp4_supported = cutlass_fp4_supported()
+        if not self.cutlass_nvfp4_supported:
+            raise ValueError("Current platform does not support NVFP4"
+                             " quantization. Please use Blackwell and above.")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError("NVFP4 quantization was selected, "
+                             " dynamic quantization is not supported.")
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        if (input_size_per_partition % 16 != 0):
+            raise ValueError("Unsupported model when in features size is "
+                             "not multiple of 16")
+        # The nvfp4 weight is still represented as
+        weight_dtype = (torch.float8_e4m3fn
+                        if self.quant_config.is_checkpoint_nvfp4_serialized
+                        else params_dtype)
+        # Weight
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                # 2 fp4 items are packed in the input dimension
+                layer.output_size_per_partition,
+                layer.input_size_per_partition // 2,
+                dtype=torch.uint8),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader)
+        layer.register_parameter("weight", weight)
+
+        # Input Weight Scale
+        input_scale = PerTensorScaleParameter(data=torch.empty(
+            len(output_partition_sizes), dtype=torch.float32),
+                                              weight_loader=weight_loader)
+        layer.register_parameter("input_scale", input_scale)
+
+        # Global Weight Scale
+        weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
+            len(output_partition_sizes), dtype=torch.float32),
+                                                 weight_loader=weight_loader)
+        layer.register_parameter("weight_scale_2", weight_scale_2)
+
+        # Per Block Weight Scale
+        weight_scale = ModelWeightParameter(data=torch.empty(
+            output_size_per_partition,
+            input_size_per_partition // self.quant_config.group_size,
+            dtype=weight_dtype,
+        ),
+                                            input_dim=1,
+                                            output_dim=0,
+                                            weight_loader=weight_loader)
+
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def swizzle_blockscale(self, scale: torch.tensor):
+        assert (scale.dtype == torch.float8_e4m3fn)
+        # Pad and blockwise interleave weight_scale
+        scale_ndim = scale.ndim
+        if scale.ndim == 2:
+            scale = scale.unsqueeze(0)
+        assert scale.ndim == 3
+        B, M, K = scale.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
+        padded_scale[:B, :M, :K] = scale
+        batches, rows, cols = padded_scale.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
+                                            cols // 4, 4)
+        swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
+        swizzled_scale = swizzled_scale.contiguous().cuda()
+        return (swizzled_scale.reshape(M, K)
+                if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+
+        # global scales:
+        input_scale_2 = layer.input_scale.max().to(torch.float32)
+        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+
+        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+
+        layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
+                                requires_grad=False)
+
+        # Swizzle the weight blockscale.
+        # contracting dimension is input dimension
+        # block_size = 16;
+        assert (layer.weight_scale.shape[1] % 16 == 0), (
+            "Expected weight_scale.dim(1) to be divisible by 16")
+        assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
+            "Weight Block scale must be represented as FP8-E4M3")
+        swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
+
+        layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
+                                                requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        output_dtype = x.dtype
+
+        # for input only the contracting dimension has a constraint.
+        x_m, _ = x.shape
+        w_n, _ = layer.weight.shape
+        output_shape = [x_m, w_n]
+
+        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
+        s_quant = 1 / layer.input_scale
+        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
+
+        # validate dtypes of quantized input, input block scale,
+        # weight and weight_blockscale
+        assert (x_fp4.dtype == torch.uint8)
+        assert (layer.weight.dtype == torch.uint8)
+        assert (x_blockscale.dtype == torch.float8_e4m3fn)
+        assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
+        assert (layer.alpha.dtype == torch.float32)
+
+        out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
+                                    layer.weight_scale_swizzled, layer.alpha,
+                                    output_dtype)
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
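
For reference, a minimal usage sketch of the new quantization path, modeled on the test added above in tests/models/decoder_only/language/test_nvfp4.py. It assumes a Blackwell-class GPU (compute capability 100 or above) and a CUDA 12.8+ runtime, which is exactly what cutlass_scaled_mm_supports_fp4 checks; the model name is the one used in the test, while the prompt and sampling settings are illustrative only.

    # Sketch: loading a ModelOpt NVFP4 checkpoint with the new
    # quantization="nvfp4" method added by this patch. Requires a GPU with
    # compute capability >= 100 and CUDA >= 12.8 (see
    # cutlass_scaled_mm_supports_fp4 in nvfp4_scaled_mm_entry.cu).
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="nvidia/Llama-3.3-70B-Instruct-FP4",  # NVFP4 checkpoint from the test
        quantization="nvfp4",
        max_model_len=1024,
        enforce_eager=True,
    )
    params = SamplingParams(max_tokens=20, temperature=0)
    outputs = llm.generate(["Compare FP8 and NVFP4 quantization."], params)
    print(outputs[0].outputs[0].text)

At run time, ModelOptNvFp4LinearMethod.apply quantizes the BF16/FP16 activations to FP4 plus an interleaved block scale via scaled_fp4_quant, then dispatches the matmul to cutlass_scaled_fp4_mm against the packed uint8 weight and the swizzled FP8-E4M3 block scales prepared in process_weights_after_loading.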