[Kernel] Add ModelOpt FP4 Checkpoint Support (#12520)

Signed-off-by: Pavani Majety <pmajety@nvidia.com>
Pavani Majety 2025-03-11 22:13:11 -07:00 committed by GitHub
parent 5c538c37b2
commit debd6bbf09
10 changed files with 388 additions and 30 deletions

View File

@@ -160,14 +160,16 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W,
int64_t ggml_moe_get_block_size(int64_t type);
#ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
torch::Tensor const& B, torch::Tensor const& A_sf,
torch::Tensor const& B_sf,
torch::Tensor const& alpha);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,

View File

@@ -36,3 +36,9 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
"be compiled using CUDA 12.8 and target "
"compute capability 100 or above.");
}
bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) {
int runtimeVersion;
cudaRuntimeGetVersion(&runtimeVersion);
return cuda_device_capability >= 100 && runtimeVersion >= 12080;
}
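
For reference, a minimal Python-side sketch of how this capability check is consumed (it mirrors the cutlass_fp4_supported() helper added later in this commit; nvfp4_gemm_available is a hypothetical name):

# Hypothetical helper; assumes vLLM's CUDA extension (torch.ops._C) is built.
from vllm import _custom_ops as ops
from vllm.platforms import current_platform

def nvfp4_gemm_available() -> bool:
    # True only on compute capability >= 100 (Blackwell) with a CUDA
    # runtime >= 12.8, matching the C++ check above.
    capability = current_platform.get_device_capability()
    if capability is None:
        return False
    return ops.cutlass_scaled_mm_supports_fp4(capability.to_int())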

View File

@@ -201,10 +201,11 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
#endif // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
#define CHECK_TYPE(x, st, m) \
TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
#define CHECK_TH_CUDA(x, m) \
TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x, m) \
TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous")
#define CHECK_INPUT(x, st, m) \
CHECK_TH_CUDA(x, m); \
CHECK_CONTIGUOUS(x, m); \

View File

@@ -434,6 +434,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor! output_scale, Tensor input_scale) -> ()");
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
// Check if cutlass_scaled_fp4_mm is supported for CUDA devices
// of the given capability
ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool");
ops.impl("cutlass_scaled_mm_supports_fp4", &cutlass_scaled_mm_supports_fp4);
#endif
// Quantized GEMM for GPTQ.

View File

@@ -0,0 +1,82 @@
# SPDX-License-Identifier: Apache-2.0
# flake8: noqa
"""Tests Model Optimizer nvfp4 models against ground truth generation
Note: these tests will only pass on B200
"""
import os
from typing import List
import pytest
from transformers import AutoTokenizer
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024
MODELS = ["nvidia/Llama-3.3-70B-Instruct-FP4"]
EXPECTED_STRS_MAP = {
"nvidia/Llama-3.3-70B-Instruct-FP4": [
'vLLM (Vectorized Large Language Model) is indeed a high-throughput and memory-efficient inference',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) are two distinct forms of intelligence that process',
'A neural network is a type of machine learning model inspired by the structure and function of the human brain',
'In the heart of a cutting-edge robotics lab, a team of engineers had been working tirelessly to push',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models, leading',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n* Japanese: (Sasuga no tori ga miwa o ts'
]
}
# This test compares generated text against golden strings for an exact match.
# There is no baseline implementation to compare against, and the output is
# unstable w.r.t. the specifics of the fp4 implementation and the hardware
# it runs on.
# Disabled to prevent it from breaking the build.
@pytest.mark.skip(
reason="Unstable golden-string comparison that could break the build; "
"the test input model is also large enough to hang the system.")
@pytest.mark.quant_model
@pytest.mark.skipif(not is_quant_method_supported("nvfp4"),
reason="nvfp4 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(example_prompts, model_name) -> None:
model = LLM(
model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="nvfp4",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
tokenizer.apply_chat_template([{
"role": "user",
"content": prompt
}],
tokenize=False,
add_generation_prompt=True)
for prompt in example_prompts
]
params = SamplingParams(max_tokens=20, temperature=0)
generations: List[str] = []
# Note: these need to be run 1 at a time due to numerical precision,
# since the expected strs were generated this way.
for prompt in formatted_prompts:
outputs = model.generate(prompt, params)
generations.append(outputs[0].outputs[0].text)
del model
print(model_name, generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]
assert expected_str == generated_str, (
f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")

View File

@@ -467,6 +467,10 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
# cutlass
def cutlass_scaled_mm_supports_fp4(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp4(cuda_device_capability)
def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
block_scale_a: torch.Tensor,
block_scale_b: torch.Tensor, alpha: torch.Tensor,
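
A hedged usage sketch of scaled_fp4_quant together with cutlass_scaled_fp4_mm (the final out_dtype argument follows the apply() call later in this commit; the global scales of 1.0 are placeholders, and a Blackwell GPU with CUDA 12.8+ is assumed):

import torch
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant

m, n, k = 128, 256, 512
x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
one = torch.tensor(1.0, dtype=torch.float32, device="cuda")

# Quantize both operands to packed FP4 (uint8, two values per byte) plus an
# interleaved FP8-E4M3 block scale: shapes [m, k // 2] and [n, k // 2].
x_fp4, x_blockscale = scaled_fp4_quant(x, one)
w_fp4, w_blockscale = scaled_fp4_quant(w, one)

# alpha folds the global scales back into the output; 1.0 here because both
# placeholder global scales are 1.0.
out = cutlass_scaled_fp4_mm(x_fp4, w_fp4, x_blockscale, w_blockscale, one,
                            torch.bfloat16)
print(out.shape)  # torch.Size([128, 256])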

View File

@@ -613,7 +613,7 @@ class ModelConfig:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8", "quark"
"compressed-tensors", "experts_int8", "quark", "nvfp4"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()

View File

@@ -30,12 +30,23 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod", "QuarkLinearMethod"
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"MarlinLinearMethod",
"QQQLinearMethod",
"GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod",
"GPTQLinearMethod",
"FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod",
"IPEXAWQLinearMethod",
"IPEXGPTQLinearMethod",
"HQQMarlinMethod",
"QuarkLinearMethod",
"ModelOptNvFp4LinearMethod",
]

View File

@@ -14,6 +14,7 @@ QUANTIZATION_METHODS: List[str] = [
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
"nvfp4",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin",
@@ -97,7 +98,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .hqq_marlin import HQQMarlinConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
@@ -112,6 +113,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,

View File

@@ -1,24 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm._custom_ops import (cutlass_scaled_fp4_mm,
cutlass_scaled_mm_supports_fp4, scaled_fp4_quant)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, requantize_with_max_scale)
from vllm.model_executor.parameter import (ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
logger = init_logger(__name__)
ACTIVATION_SCHEMES = ["static"]
QUANT_ALGOS = ["FP8", "NVFP4"]
KV_CACHE_QUANT_ALGOS = ["FP8"]
class ModelOptFp8Config(QuantizationConfig):
@@ -54,12 +61,13 @@ class ModelOptFp8Config(QuantizationConfig):
def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
if not is_checkpoint_fp8_serialized:
raise ValueError("ModelOpt currently only supports static FP8 "
"quantization in vLLM. Please check the "
if quant_method not in QUANT_ALGOS:
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
" quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_fp8_serialized = ("FP8" in quant_method)
return cls(is_checkpoint_fp8_serialized)
def get_quant_method(self, layer: torch.nn.Module,
@@ -72,15 +80,6 @@ class ModelOptFp8Config(QuantizationConfig):
return None
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: ModelOptFp8Config):
super().__init__(quant_config)
class ModelOptFp8LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
@@ -162,3 +161,250 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias)
class ModelOptNvFp4Config(QuantizationConfig):
"""Config class for ModelOpt FP4."""
def __init__(
self,
is_checkpoint_nvfp4_serialized: bool,
kv_cache_quant_algo: str,
exclude_modules: List[str],
group_size: int = 16,
) -> None:
self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
if is_checkpoint_nvfp4_serialized:
logger.warning(
"Detected ModelOpt NVFP4 checkpoint. Please note that"
" the format is experimental and could change in future.")
self.group_size = group_size
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def get_name(cls) -> str:
return "modelopt_nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
@classmethod
def get_min_capability(cls) -> int:
return 100
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["hf_quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ModelOptNvFp4Config":
quant_config = cls.get_from_keys(config, ["quantization"])
quant_method = quant_config["quant_algo"]
if quant_method not in QUANT_ALGOS:
raise ValueError(f"ModelOpt currently only supports: {QUANT_ALGOS}"
" quantizations in vLLM. Please check the "
"`hf_quant_config.json` file for your model's "
"quant configuration.")
is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)
kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
group_size = quant_config["group_size"]
exclude_modules = quant_config["exclude_modules"]
if not (group_size and kv_cache_quant_algo and exclude_modules):
raise ValueError("NVFP4 quantization requires group size and "
"kv_cache_quant_algo specified in "
"hf_quant_config.json")
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.exclude_modules):
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
return ModelOptFp8KVCacheMethod(self)
return None
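
For illustration, a hedged sketch of the kind of "quantization" section from_config consumes; the values below are hypothetical and would normally come from the checkpoint's hf_quant_config.json:

from vllm.model_executor.layers.quantization.modelopt import (
    ModelOptNvFp4Config)

# Hypothetical hf_quant_config.json contents (expressed as a Python dict);
# the values are illustrative only.
example_hf_quant_config = {
    "quantization": {
        "quant_algo": "NVFP4",
        "kv_cache_quant_algo": "FP8",
        "group_size": 16,
        "exclude_modules": ["lm_head"],
    },
}

cfg = ModelOptNvFp4Config.from_config(example_hf_quant_config)
assert cfg.is_checkpoint_nvfp4_serialized
assert cfg.group_size == 16 and cfg.exclude_modules == ["lm_head"]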
def cutlass_fp4_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return cutlass_scaled_mm_supports_fp4(capability)
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Union[ModelOptFp8Config,
ModelOptNvFp4Config]):
super().__init__(quant_config)
class ModelOptNvFp4LinearMethod(LinearMethodBase):
"""Linear method for Model Optimizer NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
input_scale: torch.float32, scalar
weight: NVFP4 (packed as bytes, two values per byte), shape [1, X, Y / 2]
weight_scale: FP8-E4M3, shape [X, Y], i.e. the per-block scale
weight_scale_2: torch.float32, scalar
Args: quant_config: The ModelOpt quantization config.
"""
def __init__(self, quant_config: ModelOptNvFp4Config):
self.quant_config = quant_config
self.cutlass_nvfp4_supported = cutlass_fp4_supported()
if not self.cutlass_nvfp4_supported:
raise ValueError("Current platform does not support NVFP4"
" quantization. Please use Blackwell and above.")
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
if not self.quant_config.is_checkpoint_nvfp4_serialized:
raise ValueError("NVFP4 quantization was selected, "
" dynamic quantization is not supported.")
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
if input_size_per_partition % 16 != 0:
raise ValueError("Unsupported model: the input feature size must be "
"a multiple of 16")
# When the checkpoint is NVFP4-serialized, the per-block weight scale
# created below is stored as FP8-E4M3; otherwise use the params dtype.
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_nvfp4_serialized
else params_dtype)
# Weight
weight = ModelWeightParameter(
data=torch.empty(
# 2 fp4 items are packed in the input dimension
layer.output_size_per_partition,
layer.input_size_per_partition // 2,
dtype=torch.uint8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# Input (activation) global scale
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
# Global Weight Scale
weight_scale_2 = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale_2", weight_scale_2)
# Per Block Weight Scale
weight_scale = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.group_size,
dtype=weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
def swizzle_blockscale(self, scale: torch.Tensor):
assert (scale.dtype == torch.float8_e4m3fn)
# Pad and blockwise interleave weight_scale
scale_ndim = scale.ndim
if scale.ndim == 2:
scale = scale.unsqueeze(0)
assert scale.ndim == 3
B, M, K = scale.shape
round_up_multiple = lambda x, m: (x + m - 1) // m * m
M_padded = round_up_multiple(M, 128)
K_padded = round_up_multiple(K, 4)
padded_scale = torch.zeros((B, M_padded, K_padded), dtype=scale.dtype)
padded_scale[:B, :M, :K] = scale
batches, rows, cols = padded_scale.shape
assert rows % 128 == 0
assert cols % 4 == 0
padded_scale = padded_scale.reshape(batches, rows // 128, 4, 32,
cols // 4, 4)
swizzled_scale = padded_scale.permute((0, 1, 4, 3, 2, 5))
swizzled_scale = swizzled_scale.contiguous().cuda()
return (swizzled_scale.reshape(M, K)
if scale_ndim == 2 else swizzled_scale.reshape(B, M, K))
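
A short trace of the shape bookkeeping in swizzle_blockscale, sketched for a hypothetical block-scale tensor whose dimensions are already multiples of 128 and 4, so the padding step is a no-op:

import torch

M, K = 256, 8      # e.g. 256 output rows, 128 in-features / group_size 16
scale = torch.randn(M, K).to(torch.float8_e4m3fn)

padded = scale.unsqueeze(0)                                # (1, 256, 8)
blocked = padded.reshape(1, M // 128, 4, 32, K // 4, 4)    # (1, 2, 4, 32, 2, 4)
swizzled = blocked.permute(0, 1, 4, 3, 2, 5).contiguous()  # (1, 2, 2, 32, 4, 4)
print(swizzled.reshape(M, K).shape)                        # torch.Size([256, 8])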
def process_weights_after_loading(self, layer: Module) -> None:
# global scales:
input_scale_2 = layer.input_scale.max().to(torch.float32)
layer.input_scale = Parameter(input_scale_2, requires_grad=False)
weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
requires_grad=False)
# Swizzle the per-block weight scale; blocks run along the
# contracting (input) dimension with block_size = 16.
assert (layer.weight_scale.shape[1] % 16 == 0), (
"Expected weight_scale.dim(1) to be divisible by 16")
assert (layer.weight_scale.dtype == torch.float8_e4m3fn), (
"Weight Block scale must be represented as FP8-E4M3")
swizzled_weight_scale = self.swizzle_blockscale(layer.weight_scale)
layer.weight_scale_swizzled = Parameter(swizzled_weight_scale,
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
output_dtype = x.dtype
# For the input, only the contracting dimension has a constraint.
x_m, _ = x.shape
w_n, _ = layer.weight.shape
output_shape = [x_m, w_n]
# Quantize BF16 or FP16 input to FP4 plus an interleaved block scale.
s_quant = 1 / layer.input_scale
x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)
# validate dtypes of quantized input, input block scale,
# weight and weight_blockscale
assert (x_fp4.dtype == torch.uint8)
assert (layer.weight.dtype == torch.uint8)
assert (x_blockscale.dtype == torch.float8_e4m3fn)
assert (layer.weight_scale_swizzled.dtype == torch.float8_e4m3fn)
assert (layer.alpha.dtype == torch.float32)
out = cutlass_scaled_fp4_mm(x_fp4, layer.weight, x_blockscale,
layer.weight_scale_swizzled, layer.alpha,
output_dtype)
if bias is not None:
out = out + bias
return out.view(*output_shape)
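
The scale algebra behind apply() can be illustrated with a small fake-quant sketch (fake_fp4_linear is a hypothetical reference, not vLLM code; FP4 rounding and per-block scales are omitted). The activation is scaled before quantization, the weight carries its own global scale from the offline ModelOpt export, and alpha = input_scale * weight_scale_2 restores both around the low-precision GEMM:

import torch

def fake_fp4_linear(x: torch.Tensor, w: torch.Tensor,
                    input_scale: float, weight_scale_2: float) -> torch.Tensor:
    # Reference-only: run the GEMM in the globally scaled domain, then undo
    # both global scales with alpha, mirroring layer.alpha above.
    x_scaled = x / input_scale        # FP4 quantization / rounding omitted
    w_scaled = w / weight_scale_2
    alpha = input_scale * weight_scale_2
    return alpha * (x_scaled @ w_scaled.t())

x = torch.randn(4, 64)
w = torch.randn(32, 64)
out = fake_fp4_linear(x, w, input_scale=2.0, weight_scale_2=0.5)
assert torch.allclose(out, x @ w.t(), atol=1e-4)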