[Misc][Quark] Upstream Quark format to VLLM (#10765)

Signed-off-by: kewang-xlnx <kewang@xilinx.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
kewang-xlnx 2025-01-16 00:05:15 +08:00 committed by GitHub
parent 5ecf3e0aaf
commit de0526f668
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 1264 additions and 70 deletions

View File

@ -0,0 +1,30 @@
"""Test model set-up and weight loading for quark-quantized models.
Run `pytest tests/quantization/test_quark.py`.
"""
import torch
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8)
def test_quark_fp8(vllm_runner):
model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert len(qkv_proj.weight_scale.shape) == 0
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

View File

@ -553,7 +553,7 @@ class ModelConfig:
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8"
"compressed-tensors", "experts_int8", "quark"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()

View File

@ -32,7 +32,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
"HQQMarlinMethod"
"HQQMarlinMethod", "QuarkLinearMethod"
]

View File

@ -26,6 +26,7 @@ QUANTIZATION_METHODS: List[str] = [
"experts_int8",
"neuron_quant",
"ipex",
"quark"
]
@ -34,6 +35,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
raise ValueError(f"Invalid quantization method: {quantization}")
# lazy import to avoid triggering `torch.compile` too early
from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
from .aqlm import AQLMConfig
from .awq import AWQConfig
from .awq_marlin import AWQMarlinConfig
@ -79,6 +82,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"experts_int8": ExpertsInt8Config,
"neuron_quant": NeuronQuantConfig,
"ipex": IPEXConfig,
"quark": QuarkConfig
}
return method_to_config[quantization]

View File

@ -133,3 +133,6 @@ class QuantizationConfig(ABC):
method.
"""
raise NotImplementedError
def get_cache_scale(self, name: str) -> Optional[str]:
return None

View File

@ -412,6 +412,22 @@ class CompressedTensorsConfig(QuantizationConfig):
self._check_scheme_supported(scheme.get_min_capability())
return scheme
def get_cache_scale(self, name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
@staticmethod
def supports_cutlass_24(
weight_quant: Optional[QuantizationArgs],

View File

@ -136,6 +136,10 @@ def triton_scaled_mm(input: torch.Tensor,
assert N > 0 and K > 0 and M > 0
assert weight.shape[0] == K
assert input.dtype == weight.dtype
scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
[M, 1])

View File

@ -133,23 +133,6 @@ def _find_first_match(value: str,
return None
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:

View File

@ -0,0 +1,387 @@
import fnmatch
import re
from typing import Any, Dict, List, Optional, cast
import torch
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
QuarkMoEMethod)
from vllm.model_executor.layers.quantization.quark.schemes import (
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
deep_compare, should_ignore_layer)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform
__all__ = ["QuarkLinearMethod"]
class QuarkConfig(QuantizationConfig):
def __init__(self,
quant_config: Dict[str, Any],
kv_cache_group: Optional[List[str]] = None,
kv_cache_config: Optional[Dict[str, Any]] = None,
pack_method: str = "reorder"):
if kv_cache_group is None:
kv_cache_group = []
self.quant_config = quant_config
self.kv_cache_group = kv_cache_group
self.kv_cache_config = kv_cache_config
self.pack_method = pack_method
def get_linear_method(self) -> "QuarkLinearMethod":
return QuarkLinearMethod(self)
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> str:
return "quark"
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
if should_ignore_layer(prefix, ignore=exclude_layers):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
return QuarkLinearMethod(self)
if isinstance(layer, Attention):
return QuarkKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return QuarkMoEMethod.get_moe_method(self,
module=layer,
layer_name=prefix)
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
export_config = config.get("export")
if export_config is None:
raise ValueError("The export key should be included in "
"the configurations of Quark quantized model")
kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
pack_method = cast(str, export_config.get("pack_method"))
# In the export model of quark, the quantization configuration
# of kv_cache is stored in layer_quant_config. First, it is
# judged whether kv_cache_group exists, and then it is judged
# whether layer_quant_config has a quantization configuration
# that matches kv_cache.
if len(kv_cache_group) == 0:
kv_cache_config = None
else:
kv_cache_set = set(kv_cache_group)
layer_quant_config = cast(Dict[str, Any],
config.get("layer_quant_config"))
layer_quant_names = list(layer_quant_config.keys())
layer_quant_set = set(layer_quant_names)
if not kv_cache_set.issubset(layer_quant_set):
raise ValueError("The Quark quantized model has the "
"kv_cache_group parameter setting, "
"but no kv_cache quantization settings "
"were found in the quantization "
"configuration.")
q_configs = [
cast(Dict[str, Any], layer_quant_config.get(name))
for name in kv_cache_group
]
if not all(
deep_compare(q_config, q_configs[0])
for q_config in q_configs):
raise ValueError(
"The quantization method used for kv_cache should "
"be the same, but the quantization method for the "
"kv_cache layer in the config is different.")
kv_cache_config = q_configs[0].get("output_tensors")
if kv_cache_config is None:
raise ValueError(
"The kv_cache quantization configuration is empty.")
# Since we have already set kv_cache quantization configurations,
# we will remove the quantization configuration for the
# output_tensors corresponding to the kv_cache layer.
for q_config in q_configs:
q_config["output_tensors"] = None
return cls(quant_config=config,
kv_cache_group=kv_cache_group,
kv_cache_config=kv_cache_config,
pack_method=pack_method)
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability_tuple = current_platform.get_device_capability()
if capability_tuple is not None:
capability = capability_tuple.to_int()
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
return supported
else:
return False
def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
# Confirm weight scheme is supported
is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3"
and input_quant.get("dtype") == "fp8_e4m3")
is_static_weight = not weight_quant.get("is_dynamic")
is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
in ["per_tensor", "per_channel"])
if not (is_fp8_dtype and is_static_weight
and is_per_tensor_or_channel_weight):
return False
# Dynamic quantization is always supported if weights supported.
if input_quant.get("is_dynamic"):
return True
# Confirm activation scheme is supported.
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
return is_per_tensor_activation
def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
input_quant: Optional[Dict[str, Any]]) -> bool:
# Confirm weights and input quantized.
if weight_quant is None or input_quant is None:
return False
is_int8_dtype = (weight_quant.get("dtype") == "int8"
and input_quant.get("dtype") == "int8")
is_tensor = (weight_quant.get("qscheme")
in ["per_tensor", "per_channel"]
and input_quant.get("qscheme") == "per_tensor")
is_static = (not weight_quant.get("is_dynamic")
and not input_quant.get("is_dynamic"))
is_weight_symmetric = (weight_quant.get("symmetric") is True)
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
def _find_matched_config(self, layer_name: str,
module: torch.nn.Module) -> Dict[str, Any]:
proj_name = layer_name.split(".")[-1]
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
shard_configs = [
self._find_matched_config(shard_name, module)
for shard_name in shard_names
]
if not all(
deep_compare(q_config, shard_configs[0])
for q_config in shard_configs):
raise ValueError(
f"Found a different quantization configuration for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme.")
return shard_configs[0]
else:
layer_quant_config = cast(
Dict[str, Any], self.quant_config.get("layer_quant_config"))
for name_pattern in layer_quant_config:
if fnmatch.fnmatch(layer_name, name_pattern):
return layer_quant_config[name_pattern]
layer_type = cast(str, type(module))
layer_type_quant_config = cast(
Dict[str, Any],
self.quant_config.get("layer_type_quant_config"))
if layer_type in layer_type_quant_config:
return layer_type_quant_config[layer_type]
global_quant_config = cast(
Dict[str, Any], self.quant_config.get("global_quant_config"))
return global_quant_config
def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
if config.get("output_tensors") or config.get("bias"):
raise NotImplementedError(
"Currently, Quark models with output_tensors "
"and bias quantized are not supported")
weight_config = cast(Dict[str, Any], config.get("weight"))
input_config = cast(Dict[str, Any], config.get("input_tensors"))
if self._is_fp8_w8a8(weight_config, input_config):
is_fp8_w8a8_supported = self._check_scheme_supported(
QuarkW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
weight_qscheme = cast(str, weight_config.get("qscheme"))
input_static = (input_config is not None and
not cast(bool, input_config.get("is_dynamic")))
return QuarkW8A8Fp8(qscheme=weight_qscheme,
is_static_input_scheme=input_static)
elif self._is_static_tensor_w8a8(weight_config, input_config):
weight_qscheme = cast(str, weight_config.get("qscheme"))
return QuarkW8A8Int8(qscheme=weight_qscheme,
is_static_input_scheme=True,
input_symmetric=input_config.get("symmetric"))
raise NotImplementedError("No quark compatible scheme was found. "
f"Weight config: {weight_config}, "
f"Input config: {input_config}")
def get_scheme(self, layer: torch.nn.Module,
layer_name: str) -> "QuarkScheme":
layer_quant_config = self._find_matched_config(layer_name, layer)
# Find the quant_scheme
scheme = self._get_scheme_from_config(layer_quant_config)
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme
def get_cache_scale(self, name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in quark. If this is the case, return its equivalent param name
expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
return None
kv_proj_names = [
re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
]
if name.endswith(".output_scale"):
if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
return name.replace(kv_output_scale_name, ".attn.k_scale")
elif len(kv_proj_names) == 2:
for kv_proj_name in kv_proj_names:
if kv_proj_name in name and kv_proj_name == "k_proj":
return name.replace(".k_proj.output_scale",
".attn.k_scale")
elif kv_proj_name in name and kv_proj_name == "v_proj":
return name.replace(".v_proj.output_scale",
".attn.v_scale")
# If no matches, return None
return None
class QuarkLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: QuarkConfig):
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer.scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details
"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
class QuarkKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from quark checkpoints.
"""
def __init__(self, quant_config: QuarkConfig):
self.validate_kv_cache_config(quant_config.kv_cache_config)
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
"""
Validator for the kv cache configuration. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
:param kv_cache_config: the quark kv cache scheme
"""
if kv_cache_config is None:
return
dtype = kv_cache_config.get("dtype")
if dtype != "fp8_e4m3":
raise NotImplementedError(
"Currently supported kv cache quantization is "
f"dtype=fp8_e4m3, however received {dtype}")
qscheme = kv_cache_config.get("qscheme")
if qscheme != "per_tensor":
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for quark KV cache. "
f"Expected qscheme: per_tensor, found qscheme: {qscheme}")

View File

@ -0,0 +1,225 @@
from typing import Any, Callable, Dict, Optional
import torch
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]
class QuarkMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
module: torch.nn.Module,
layer_name: str) -> "QuarkMoEMethod":
layer_quant_config = quant_config._find_matched_config(
layer_name, module)
if (layer_quant_config.get("output_tensors")
or layer_quant_config.get("bias")):
raise NotImplementedError("Currently, Quark models with "
"output_tensors and bias "
"quantized are not supported")
weight_config = layer_quant_config.get("weight")
input_config = layer_quant_config.get("input_tensors")
if quant_config._is_fp8_w8a8(weight_config, input_config):
return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
else:
raise RuntimeError("Unsupported FusedMoe scheme")
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
Any]):
self.weight_quant = weight_config
self.input_quant = input_config
weight_qscheme = self.weight_quant.get("qscheme")
input_qscheme = self.input_quant.get("qscheme")
if not (weight_qscheme == "per_tensor"
and input_qscheme == "per_tensor"):
raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales"
"for weights and activations are supported. Found "
f"{weight_qscheme}, {input_qscheme}") # noqa E501
self.static_input_scales = not self.input_quant.get("is_dynamic")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
logger.warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if current_platform.is_rocm():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale,
layer.w13_input_scale)
w2_weight, w2_weight_scale, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)

View File

@ -0,0 +1,5 @@
from .quark_scheme import QuarkScheme
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]

View File

@ -0,0 +1,52 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
__all__ = ["QuarkScheme"]
class QuarkScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by Quark.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise NotImplementedError
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError

View File

@ -0,0 +1,140 @@
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.platforms import current_platform
__all__ = ["QuarkW8A8Fp8"]
class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.qscheme == "per_tensor":
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
if current_platform.is_rocm():
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.qscheme == "per_channel":
weight = layer.weight
if current_platform.is_rocm():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
else:
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
raise ValueError(f"Unknown quantization scheme {self.qscheme}")
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)

View File

@ -0,0 +1,105 @@
from typing import Callable, List, Optional, Set
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
logger = init_logger(__name__)
class QuarkW8A8Int8(QuarkScheme):
_kernel_backends_being_used: Set[str] = set()
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
input_symmetric: Optional[bool]):
self.qscheme = qscheme
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
is_channelwise=(self.qscheme == "per_channel"),
is_static_input_scheme=(self.is_static_input_scheme is True),
input_symmetric=(self.input_symmetric is True))
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: quark stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
w_q_param_name="weight",
w_s_param_name="weight_scale",
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj")
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@ -0,0 +1,99 @@
import re
from typing import Any, Iterable, Optional
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
def deep_compare(dict1: Any, dict2: Any) -> bool:
if type(dict1) is not type(dict2):
return False
if isinstance(dict1, dict):
if dict1.keys() != dict2.keys():
return False
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
elif isinstance(dict1, list):
return set(dict1) == set(dict2)
else:
return dict1 == dict2
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
if layer_name is None:
return False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name = layer_name.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer = None
for shard_name in shard_names:
should_ignore_shard = check_equal_or_regex_match(
layer_name=shard_name, targets=ignore)
# If shard_idx=0, set layer ignore to match shard.
if should_ignore_layer is None:
should_ignore_layer = should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif should_ignore_shard != should_ignore_layer:
raise ValueError(f"Found a different quantization schemes for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme.")
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else:
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
targets=ignore)
assert should_ignore_layer is not None
return should_ignore_layer
def check_equal_or_regex_match(layer_name: str,
targets: Iterable[str]) -> bool:
"""
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False

View File

@ -13,8 +13,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.sampler import (SamplerOutput,
SamplingMetadata, get_sampler)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@ -390,12 +388,15 @@ class AriaMoELMModel(LlamaModel):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

View File

@ -437,6 +437,20 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue

View File

@ -83,7 +83,7 @@ class DbrxExperts(FusedMoE):
# Define custom weight loader for dbrx model
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str):
weight_name: str, param_name: str):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
@ -91,25 +91,37 @@ class DbrxExperts(FusedMoE):
# DBRX uses GLU for each experts.
# GLU has 3 linear layers: w1, v1 and w2.
if weight_name.endswith("w1"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
elif param_name.endswith("weight_scale"):
param_data[:, 0] = loaded_weight
else:
param_data = loaded_weight
if weight_name.endswith("v1"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
)
param_data[:,
shard_size:2 * shard_size, :] = loaded_weight[:,
shard, :]
param_data[:, shard_size:2 *
shard_size, :] = loaded_weight[:, shard, :]
elif param_name.endswith("weight_scale"):
param_data[:, 1] = loaded_weight
else:
param_data[:] = loaded_weight
if weight_name.endswith("w2"):
if param_name.endswith("weight"):
loaded_weight = torch.reshape(
loaded_weight,
[-1, self.intermediate_size * self.tp_size, self.d_model],
).transpose(1, 2)
param_data[:] = loaded_weight[:, :, shard]
else:
param_data[:] = loaded_weight
class DbrxMoE(nn.Module):
@ -430,14 +442,29 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
expert_params_mapping = [(
"w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
"w13" if weight_name in ["w1", "v1"] else "w2",
f"mlp.{weight_name}",
) for weight_name in ["w1", "v1", "w2"]]
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if name.endswith(("w1", "w2", "v1")):
name = name + "_weight"
for param_name, weight_name in expert_params_mapping:
if weight_name not in name:
continue
@ -446,8 +473,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, weight_name)
weight_loader(param, loaded_weight, weight_name, name)
break
else:
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
@ -456,6 +484,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
if is_pp_missing_parameter(name, self):
continue
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

View File

@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -439,6 +437,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.transformer = ExaoneModel(
vllm_config=vllm_config,
@ -532,12 +531,15 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

View File

@ -31,8 +31,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -254,6 +252,7 @@ class Gemma2Model(nn.Module):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
@ -329,7 +328,8 @@ class Gemma2Model(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if scale_name := get_compressed_tensors_cache_scale(name):
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",

View File

@ -313,6 +313,20 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
for name, loaded_weight in weights:
if "attn.bias" in name or "attn.masked_bias" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -371,6 +369,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = GraniteModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@ -474,12 +473,15 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# processed with quantization, LoRA, fine-tuning, etc.
if self.config.tie_word_embeddings and "lm_head.weight" in name:
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

View File

@ -38,8 +38,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -306,6 +304,7 @@ class LlamaModel(nn.Module):
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
@ -396,12 +395,15 @@ class LlamaModel(nn.Module):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

View File

@ -347,6 +347,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = MixtralModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@ -428,6 +429,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -1116,6 +1116,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.quant_config = quant_config
self.vocab_size = config.text_config.vocab_size
self.hidden_size = config.text_config.hidden_size
self.max_num_tiles = config.vision_config.max_num_tiles
@ -1429,6 +1430,18 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
name = name.replace('patch_embedding.weight',
'patch_embedding._linear.weight')
loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
updated_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -405,6 +405,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = NemotronModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@ -489,6 +490,18 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -546,6 +546,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = vllm_config.quant_config
self.model = PhiMoEModel(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
@ -623,6 +624,19 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -279,6 +279,7 @@ class Qwen2Model(nn.Module):
))
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
@ -364,6 +365,18 @@ class Qwen2Model(nn.Module):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue

View File

@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -409,6 +407,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.model = SolarModel(
vllm_config=vllm_config,
@ -491,12 +490,15 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if scale_name := get_compressed_tensors_cache_scale(name):
# Loading kv cache scales for compressed-tensors quantization
if (self.quant_config is not None and
(scale_name := self.quant_config.get_cache_scale(name))):
# Loading kv cache scales for quark and
# compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
loaded_weight = loaded_weight[0]
loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
loaded_weight[0])
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue

View File

@ -56,8 +56,14 @@ class BasevLLMParameter(Parameter):
def weight_loader(self):
return self._weight_loader
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
cond1 = self.data.ndim == 1 and self.data.numel() == 1
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
return (cond1 and cond2)
def _assert_and_load(self, loaded_weight: torch.Tensor):
assert self.data.shape == loaded_weight.shape
assert (self.data.shape == loaded_weight.shape
or self._is_1d_and_scalar(loaded_weight))
self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):

View File

@ -70,7 +70,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf"
"fbgemm_fp8", "gguf", "quark"
]
@classmethod