Pavani Majety debd6bbf09
[Kernel] Add ModelOpt FP4 Checkpoint Support (#12520)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
2025-03-12 05:13:11 +00:00

147 lines
4.8 KiB
Python

# SPDX-License-Identifier: Apache-2.0
from typing import Dict, List, Type
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QUANTIZATION_METHODS: List[str] = [
"aqlm",
"awq",
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
"nvfp4",
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin",
"gguf",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",
"gptq",
"compressed-tensors",
"bitsandbytes",
"qqq",
"hqq",
"experts_int8",
"neuron_quant",
"ipex",
"quark",
"moe_wna16"
]
# The customized quantization methods which will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
def register_quantization_config(quantization: str):
"""Register a customized vllm quantization config.
When a quantization method is not supported by vllm, you can register a customized
quantization config to support it.
Args:
quantization (str): The quantization method name.
Examples:
>>> from vllm.model_executor.layers.quantization import register_quantization_config
>>> from vllm.model_executor.layers.quantization import get_quantization_config
>>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
>>>
>>> @register_quantization_config("my_quant")
... class MyQuantConfig(QuantizationConfig):
... pass
>>>
>>> get_quantization_config("my_quant")
<class 'MyQuantConfig'>
""" # noqa: E501
def _wrapper(quant_config_cls):
if quantization in QUANTIZATION_METHODS:
raise ValueError(
f"The quantization method `{quantization}` is already exists.")
if not issubclass(quant_config_cls, QuantizationConfig):
raise ValueError("The quantization config must be a subclass of "
"`QuantizationConfig`.")
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
QUANTIZATION_METHODS.append(quantization)
return quant_config_cls
return _wrapper
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
# lazy import to avoid triggering `torch.compile` too early
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .aqlm import AQLMConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
from .bitsandbytes import BitsAndBytesConfig
from .compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from .deepspeedfp import DeepSpeedFPConfig
from .experts_int8 import ExpertsInt8Config
from .fbgemm_fp8 import FBGEMMFp8Config
from .fp8 import Fp8Config
from .gguf import GGUFConfig
from .gptq import GPTQConfig
from .gptq_marlin import GPTQMarlinConfig
from .gptq_marlin_24 import GPTQMarlin24Config
from .hqq_marlin import HQQMarlinConfig
from .ipex_quant import IPEXConfig
from .marlin import MarlinConfig
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
from .moe_wna16 import MoeWNA16Config
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"gguf": GGUFConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"ptpc_fp8": PTPCFp8Config,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
"quark": QuarkConfig,
"moe_wna16": MoeWNA16Config,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
return method_to_config[quantization]
__all__ = [
"QuantizationConfig",
"get_quantization_config",
"QUANTIZATION_METHODS",
]