[Misc][Quark] Upstream Quark format to VLLM (#10765)

Signed-off-by: kewang-xlnx <kewang@xilinx.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
parent 5ecf3e0aaf
commit de0526f668
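
For context, the snippet below is an editor-added usage sketch (not part of this commit) showing how a Quark-quantized checkpoint can be loaded for offline inference once this change lands. The model path is the one exercised by the new test further down; passing quantization="quark" explicitly is an illustrative assumption, since the checkpoint's own config normally selects the method.

# Editor's sketch, not part of the diff: loading a Quark FP8 checkpoint.
# The explicit quantization="quark" argument is illustrative.
from vllm import LLM, SamplingParams

llm = LLM(model="amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test",
          quantization="quark")
outputs = llm.generate(["Hello my name is"],
                       SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)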
tests/quantization/test_quark.py (new file, 30 lines)
@@ -0,0 +1,30 @@
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.
"""

import torch

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
    QuarkLinearMethod, QuarkW8A8Fp8)


def test_quark_fp8(vllm_runner):
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(model_path) as llm:
        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]

        qkv_proj = layer.self_attn.qkv_proj

        assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
        assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

        if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
            #assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
            assert len(qkv_proj.weight_scale.shape) == 0

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
@@ -553,7 +553,7 @@ class ModelConfig:
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed_tensors",
-            "compressed-tensors", "experts_int8"
+            "compressed-tensors", "experts_int8", "quark"
         ]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -32,7 +32,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
     "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod",
-    "HQQMarlinMethod"
+    "HQQMarlinMethod", "QuarkLinearMethod"
 ]

@@ -26,6 +26,7 @@ QUANTIZATION_METHODS: List[str] = [
     "experts_int8",
     "neuron_quant",
     "ipex",
+    "quark"
 ]

@@ -34,6 +35,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         raise ValueError(f"Invalid quantization method: {quantization}")

     # lazy import to avoid triggering `torch.compile` too early
+    from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig
+
     from .aqlm import AQLMConfig
     from .awq import AWQConfig
     from .awq_marlin import AWQMarlinConfig
@@ -79,6 +82,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "experts_int8": ExpertsInt8Config,
         "neuron_quant": NeuronQuantConfig,
         "ipex": IPEXConfig,
+        "quark": QuarkConfig
     }

     return method_to_config[quantization]
@@ -133,3 +133,6 @@ class QuantizationConfig(ABC):
         method.
         """
         raise NotImplementedError
+
+    def get_cache_scale(self, name: str) -> Optional[str]:
+        return None
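
To make the new hook concrete, here is a short editor's sketch of the contract (the layer name is a made-up example): implementations added in this commit map a checkpoint KV-cache scale name onto the attention parameter name vLLM expects, while the base-class default of None means "not a KV-cache scale, load normally".

# Editor's sketch of the get_cache_scale() contract; the name is hypothetical.
name = "model.layers.0.self_attn.k_proj.output_scale"
scale_name = quant_config.get_cache_scale(name)
# -> "model.layers.0.self_attn.attn.k_scale" for configs that quantize the
#    KV cache (see CompressedTensorsConfig and QuarkConfig below), else None.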
@@ -412,6 +412,22 @@ class CompressedTensorsConfig(QuantizationConfig):
             self._check_scheme_supported(scheme.get_min_capability())
         return scheme

+    def get_cache_scale(self, name: str) -> Optional[str]:
+        """
+        Check whether the param name matches the format for k/v cache scales
+        in compressed-tensors. If this is the case, return its equivalent
+        param name expected by vLLM
+
+        :param name: param name
+        :return: matching param name for KV cache scale in vLLM
+        """
+        if name.endswith(".output_scale") and ".k_proj" in name:
+            return name.replace(".k_proj.output_scale", ".attn.k_scale")
+        if name.endswith(".output_scale") and ".v_proj" in name:
+            return name.replace(".v_proj.output_scale", ".attn.v_scale")
+        # If no matches, return None
+        return None
+
     @staticmethod
     def supports_cutlass_24(
             weight_quant: Optional[QuantizationArgs],
@@ -136,6 +136,10 @@ def triton_scaled_mm(input: torch.Tensor,
     assert N > 0 and K > 0 and M > 0
     assert weight.shape[0] == K
     assert input.dtype == weight.dtype
+
+    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
+    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b
+
     assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
     assert scale_a.shape == torch.Size([1, 1]) or scale_a.shape == torch.Size(
         [M, 1])
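
The two added lines normalize scalar (0-D) or 1-D scales into the column-vector layout the subsequent shape asserts expect; a small editor's sketch of the shapes involved (values are arbitrary):

# Editor's sketch of the reshape above; values are illustrative.
import torch

scale_a = torch.tensor(0.02)          # per-tensor scalar, dim() == 0
scale_a = scale_a.reshape(-1, 1)      # -> torch.Size([1, 1])

scale_b = torch.rand(16)              # one scale per row, dim() == 1
scale_b = scale_b.reshape(-1, 1)      # -> torch.Size([16, 1])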
@@ -133,23 +133,6 @@ def _find_first_match(value: str,
     return None


-def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
-    """
-    Check whether the param name matches the format for k/v cache scales
-    in compressed-tensors. If this is the case, return its equivalent
-    param name expected by vLLM
-
-    :param name: param name
-    :return: matching param name for KV cache scale in vLLM
-    """
-    if name.endswith(".output_scale") and ".k_proj" in name:
-        return name.replace(".k_proj.output_scale", ".attn.k_scale")
-    if name.endswith(".output_scale") and ".v_proj" in name:
-        return name.replace(".v_proj.output_scale", ".attn.v_scale")
-    # If no matches, return None
-    return None
-
-
 def _is_equal_or_regex_match(value: str,
                              target: str,
                              check_contains: bool = False) -> bool:
vllm/model_executor/layers/quantization/quark/quark.py (new file, 387 lines)
@@ -0,0 +1,387 @@
import fnmatch
import re
from typing import Any, Dict, List, Optional, cast

import torch

from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
    QuarkMoEMethod)
from vllm.model_executor.layers.quantization.quark.schemes import (
    QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
    deep_compare, should_ignore_layer)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)
from vllm.platforms import current_platform

__all__ = ["QuarkLinearMethod"]


class QuarkConfig(QuantizationConfig):

    def __init__(self,
                 quant_config: Dict[str, Any],
                 kv_cache_group: Optional[List[str]] = None,
                 kv_cache_config: Optional[Dict[str, Any]] = None,
                 pack_method: str = "reorder"):
        if kv_cache_group is None:
            kv_cache_group = []
        self.quant_config = quant_config
        self.kv_cache_group = kv_cache_group
        self.kv_cache_config = kv_cache_config
        self.pack_method = pack_method

    def get_linear_method(self) -> "QuarkLinearMethod":
        return QuarkLinearMethod(self)

    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def get_name(self) -> str:
        return "quark"

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        # Check if the layer is skipped for quantization.
        exclude_layers = cast(List[str], self.quant_config.get("exclude"))
        if should_ignore_layer(prefix, ignore=exclude_layers):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            layer.scheme = scheme
            return QuarkLinearMethod(self)
        if isinstance(layer, Attention):
            return QuarkKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return QuarkMoEMethod.get_moe_method(self,
                                                 module=layer,
                                                 layer_name=prefix)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
        export_config = config.get("export")
        if export_config is None:
            raise ValueError("The export key should be included in "
                             "the configurations of Quark quantized model")
        kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
        pack_method = cast(str, export_config.get("pack_method"))

        # In the export model of quark, the quantization configuration
        # of kv_cache is stored in layer_quant_config. First, it is
        # judged whether kv_cache_group exists, and then it is judged
        # whether layer_quant_config has a quantization configuration
        # that matches kv_cache.
        if len(kv_cache_group) == 0:
            kv_cache_config = None
        else:
            kv_cache_set = set(kv_cache_group)
            layer_quant_config = cast(Dict[str, Any],
                                      config.get("layer_quant_config"))
            layer_quant_names = list(layer_quant_config.keys())
            layer_quant_set = set(layer_quant_names)

            if not kv_cache_set.issubset(layer_quant_set):
                raise ValueError("The Quark quantized model has the "
                                 "kv_cache_group parameter setting, "
                                 "but no kv_cache quantization settings "
                                 "were found in the quantization "
                                 "configuration.")

            q_configs = [
                cast(Dict[str, Any], layer_quant_config.get(name))
                for name in kv_cache_group
            ]
            if not all(
                    deep_compare(q_config, q_configs[0])
                    for q_config in q_configs):
                raise ValueError(
                    "The quantization method used for kv_cache should "
                    "be the same, but the quantization method for the "
                    "kv_cache layer in the config is different.")
            kv_cache_config = q_configs[0].get("output_tensors")
            if kv_cache_config is None:
                raise ValueError(
                    "The kv_cache quantization configuration is empty.")

            # Since we have already set kv_cache quantization configurations,
            # we will remove the quantization configuration for the
            # output_tensors corresponding to the kv_cache layer.
            for q_config in q_configs:
                q_config["output_tensors"] = None

        return cls(quant_config=config,
                   kv_cache_group=kv_cache_group,
                   kv_cache_config=kv_cache_config,
                   pack_method=pack_method)

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    def _check_scheme_supported(self,
                                min_capability: int,
                                error: bool = True) -> bool:
        capability_tuple = current_platform.get_device_capability()

        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for ",
                    f"the current GPU. Min capability: {min_capability}. ",
                    f"Current capability: {capability}.")
            return supported
        else:
            return False

    def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                     input_quant: Optional[Dict[str, Any]]) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm weight scheme is supported
        is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3"
                        and input_quant.get("dtype") == "fp8_e4m3")
        is_static_weight = not weight_quant.get("is_dynamic")
        is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
                                           in ["per_tensor", "per_channel"])

        if not (is_fp8_dtype and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.get("is_dynamic"):
            return True

        # Confirm activation scheme is supported.
        is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
        return is_per_tensor_activation

    def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
                               input_quant: Optional[Dict[str, Any]]) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            return False

        is_int8_dtype = (weight_quant.get("dtype") == "int8"
                         and input_quant.get("dtype") == "int8")

        is_tensor = (weight_quant.get("qscheme")
                     in ["per_tensor", "per_channel"]
                     and input_quant.get("qscheme") == "per_tensor")

        is_static = (not weight_quant.get("is_dynamic")
                     and not input_quant.get("is_dynamic"))

        is_weight_symmetric = (weight_quant.get("symmetric") is True)

        # Both symmetric and asymmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

    def _find_matched_config(self, layer_name: str,
                             module: torch.nn.Module) -> Dict[str, Any]:

        proj_name = layer_name.split(".")[-1]
        if proj_name in FUSED_LAYER_NAME_MAPPING:
            shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]

            # Convert fused_name --> [shard_names]
            shard_names = [
                layer_name.replace(proj_name, shard_proj_name)
                for shard_proj_name in shard_proj_names
            ]
            shard_configs = [
                self._find_matched_config(shard_name, module)
                for shard_name in shard_names
            ]
            if not all(
                    deep_compare(q_config, shard_configs[0])
                    for q_config in shard_configs):
                raise ValueError(
                    f"Found a different quantization configuration for "
                    f"{shard_proj_names} in {layer_name}. vLLM "
                    "requires all to use the same scheme.")
            return shard_configs[0]
        else:
            layer_quant_config = cast(
                Dict[str, Any], self.quant_config.get("layer_quant_config"))
            for name_pattern in layer_quant_config:
                if fnmatch.fnmatch(layer_name, name_pattern):
                    return layer_quant_config[name_pattern]

            layer_type = cast(str, type(module))
            layer_type_quant_config = cast(
                Dict[str, Any],
                self.quant_config.get("layer_type_quant_config"))
            if layer_type in layer_type_quant_config:
                return layer_type_quant_config[layer_type]

            global_quant_config = cast(
                Dict[str, Any], self.quant_config.get("global_quant_config"))
            return global_quant_config

    def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
        if config.get("output_tensors") or config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with output_tensors "
                "and bias quantized are not supported")
        weight_config = cast(Dict[str, Any], config.get("weight"))
        input_config = cast(Dict[str, Any], config.get("input_tensors"))

        if self._is_fp8_w8a8(weight_config, input_config):
            is_fp8_w8a8_supported = self._check_scheme_supported(
                QuarkW8A8Fp8.get_min_capability(), error=False)
            if is_fp8_w8a8_supported:
                weight_qscheme = cast(str, weight_config.get("qscheme"))
                input_static = (input_config is not None and
                                not cast(bool, input_config.get("is_dynamic")))
                return QuarkW8A8Fp8(qscheme=weight_qscheme,
                                    is_static_input_scheme=input_static)
        elif self._is_static_tensor_w8a8(weight_config, input_config):
            weight_qscheme = cast(str, weight_config.get("qscheme"))
            return QuarkW8A8Int8(qscheme=weight_qscheme,
                                 is_static_input_scheme=True,
                                 input_symmetric=input_config.get("symmetric"))

        raise NotImplementedError("No quark compatible scheme was found. "
                                  f"Weight config: {weight_config}, "
                                  f"Input config: {input_config}")

    def get_scheme(self, layer: torch.nn.Module,
                   layer_name: str) -> "QuarkScheme":

        layer_quant_config = self._find_matched_config(layer_name, layer)

        # Find the quant_scheme
        scheme = self._get_scheme_from_config(layer_quant_config)
        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in quark. If this is the case, return its equivalent param name
        expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
            return None

        kv_proj_names = [
            re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
        ]
        if name.endswith(".output_scale"):
            if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
                kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
                return name.replace(kv_output_scale_name, ".attn.k_scale")

            elif len(kv_proj_names) == 2:
                for kv_proj_name in kv_proj_names:
                    if kv_proj_name in name and kv_proj_name == "k_proj":
                        return name.replace(".k_proj.output_scale",
                                            ".attn.k_scale")
                    elif kv_proj_name in name and kv_proj_name == "v_proj":
                        return name.replace(".v_proj.output_scale",
                                            ".attn.v_scale")

        # If no matches, return None
        return None


class QuarkLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: QuarkConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details

        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)


class QuarkKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from quark checkpoints.
    """

    def __init__(self, quant_config: QuarkConfig):
        self.validate_kv_cache_config(quant_config.kv_cache_config)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
        """
        Validator for the kv cache configuration. Useful for controlling the
        kv cache quantization schemes, that are being supported in vLLM
        :param kv_cache_config: the quark kv cache scheme
        """
        if kv_cache_config is None:
            return

        dtype = kv_cache_config.get("dtype")
        if dtype != "fp8_e4m3":
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                f"dtype=fp8_e4m3, however received {dtype}")

        qscheme = kv_cache_config.get("qscheme")
        if qscheme != "per_tensor":
            raise NotImplementedError(
                "Only support per-tensor scaling factor "
                "for quark KV cache. "
                f"Expected qscheme: per_tensor, found qscheme: {qscheme}")
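
For readers unfamiliar with Quark's export format, the following is an editor's sketch of the quantization config dictionary that QuarkConfig.from_config() consumes. The keys mirror the accesses made in the code above (export, kv_cache_group, layer_quant_config, layer_type_quant_config, global_quant_config, exclude); the concrete values are illustrative assumptions, not an official schema.

# Editor's sketch of a Quark quantization config; values are made up.
quark_quant_config = {
    "export": {"kv_cache_group": [], "pack_method": "reorder"},
    "global_quant_config": {
        "weight": {"dtype": "fp8_e4m3", "qscheme": "per_tensor",
                   "is_dynamic": False, "symmetric": True},
        "input_tensors": {"dtype": "fp8_e4m3", "qscheme": "per_tensor",
                          "is_dynamic": False, "symmetric": True},
        "output_tensors": None,
        "bias": None,
    },
    "layer_quant_config": {},
    "layer_type_quant_config": {},
    "exclude": ["lm_head"],
}

quark_config = QuarkConfig.from_config(quark_quant_config)
# With this config, every linear layer falls back to global_quant_config and
# _get_scheme_from_config() selects QuarkW8A8Fp8 (static, per-tensor), while
# "lm_head" is excluded and kept unquantized.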
vllm/model_executor/layers/quantization/quark/quark_moe.py (new file, 225 lines)
@@ -0,0 +1,225 @@
from typing import Any, Callable, Dict, Optional

import torch

import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

logger = init_logger(__name__)

__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]


class QuarkMoEMethod(FusedMoEMethodBase):

    @staticmethod
    def get_moe_method(
            quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
            module: torch.nn.Module,
            layer_name: str) -> "QuarkMoEMethod":
        layer_quant_config = quant_config._find_matched_config(
            layer_name, module)

        if (layer_quant_config.get("output_tensors")
                or layer_quant_config.get("bias")):
            raise NotImplementedError("Currently, Quark models with "
                                      "output_tensors and bias "
                                      "quantized are not supported")
        weight_config = layer_quant_config.get("weight")
        input_config = layer_quant_config.get("input_tensors")

        if quant_config._is_fp8_w8a8(weight_config, input_config):
            return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
        else:
            raise RuntimeError("Unsupported FusedMoe scheme")


class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):

    def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
                                                                          Any]):
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_tensor"
                and input_qscheme == "per_tensor"):
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor scales"
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}")  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                    2 * intermediate_size,
                                                    hidden_size,
                                                    dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(num_experts,
                                                   hidden_size,
                                                   intermediate_size,
                                                   dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. ")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        # If rocm, normalize the weights and scales to e4m3fnuz
        if current_platform.is_rocm():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
                                                           requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                          requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts(x,
                             layer.w13_weight,
                             layer.w2_weight,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             inplace=True,
                             use_fp8_w8a8=True,
                             w1_scale=layer.w13_weight_scale,
                             w2_scale=layer.w2_weight_scale,
                             a1_scale=layer.w13_input_scale,
                             a2_scale=layer.w2_input_scale)
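
The w13 requantization above can be summarized with a small editor's sketch (arbitrary values): each shard is dequantized with its own scale and re-quantized with the per-expert maximum, so both shards end up sharing a single scale without changing the weights they represent.

# Editor's numeric sketch of the max-scale requantization; values are made up.
import torch

s1, s2 = torch.tensor(0.5), torch.tensor(2.0)   # per-shard scales
q1 = torch.tensor([10., -4.])                   # quantized values of shard 1
max_scale = torch.maximum(s1, s2)               # 2.0, shared afterwards
requant_q1 = q1 * s1 / max_scale                # dequantize, then requantize
# q1 * s1 == requant_q1 * max_scale, so the represented weights are unchanged.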
vllm/model_executor/layers/quantization/quark/schemes/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from .quark_scheme import QuarkScheme
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8

__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]
vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from abc import ABC, abstractmethod
from typing import Optional

import torch

__all__ = ["QuarkScheme"]


class QuarkScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass
    of different quantization schemes supported by Quark.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Weight creation for the particular scheme. Inputs to this function

        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: bias parameter

        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py (new file, 140 lines)
@@ -0,0 +1,140 @@
from typing import Callable, List, Optional

import torch
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
    requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform

__all__ = ["QuarkW8A8Fp8"]


class QuarkW8A8Fp8(QuarkScheme):

    def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
        self.qscheme = qscheme
        self.is_static_input_scheme = is_static_input_scheme
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
        if self.qscheme == "per_tensor":
            max_w_scale, weight = requantize_with_max_scale(
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                logical_widths=layer.logical_widths,
            )

            if current_platform.is_rocm():
                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=max_w_scale,
                    input_scale=layer.input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)

            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, scales are already lined up, so just transpose.
        elif self.qscheme == "per_channel":
            weight = layer.weight

            if current_platform.is_rocm():
                weight, weight_scale, input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=layer.weight_scale,
                        input_scale=layer.input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data

            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        else:
            raise ValueError(f"Unknown quantization scheme {self.qscheme}")

        # INPUT SCALE
        if self.is_static_input_scheme:
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)
        else:
            layer.input_scale = None

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        else:
            assert self.qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)

        # min requirement for fp8 kernels
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                  weight_loader=weight_loader)
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
            use_per_token_if_dynamic=True)
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from typing import Callable, List, Optional, Set

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)

logger = init_logger(__name__)


class QuarkW8A8Int8(QuarkScheme):
    _kernel_backends_being_used: Set[str] = set()

    def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
                 input_symmetric: Optional[bool]):
        self.qscheme = qscheme
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # turing and up
        return 75

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        self.logical_widths = output_partition_sizes

        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.qscheme == "per_channel"),
            is_static_input_scheme=(self.is_static_input_scheme is True),
            input_symmetric=(self.input_symmetric is True))

        kernel_type = choose_scaled_mm_linear_kernel(
            scaled_mm_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=torch.int8),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        else:
            assert self.qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = BasevLLMParameter(data=torch.empty(
                1, dtype=torch.float32),
                                            weight_loader=weight_loader)
            layer.register_parameter("input_scale", input_scale)

            if not self.input_symmetric:
                # Note: quark stores the zp using the same dtype
                # as the weights
                # AZP loaded as int8 but used as int32
                input_zero_point = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.int8),
                    weight_loader=weight_loader)
                layer.register_parameter("input_zero_point", input_zero_point)

        self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
                                  w_q_param_name="weight",
                                  w_s_param_name="weight_scale",
                                  i_s_param_name="input_scale",
                                  i_zp_param_name="input_zero_point",
                                  azp_adj_param_name="azp_adj")

    # Checkpoints are serialized in quark format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)
vllm/model_executor/layers/quantization/quark/utils.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import re
from typing import Any, Iterable, Optional

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    FUSED_LAYER_NAME_MAPPING)


def deep_compare(dict1: Any, dict2: Any) -> bool:
    if type(dict1) is not type(dict2):
        return False
    if isinstance(dict1, dict):
        if dict1.keys() != dict2.keys():
            return False
        return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
    elif isinstance(dict1, list):
        return set(dict1) == set(dict2)
    else:
        return dict1 == dict2


def should_ignore_layer(layer_name: Optional[str],
                        ignore: Iterable[str]) -> bool:
    if layer_name is None:
        return False

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in FUSED_LAYER_NAME_MAPPING:
        shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
            layer_name.replace(proj_name, shard_proj_name)
            for shard_proj_name in shard_proj_names
        ]

        # Layer should be ignored if shards are ignored.
        should_ignore_layer = None
        for shard_name in shard_names:
            should_ignore_shard = check_equal_or_regex_match(
                layer_name=shard_name, targets=ignore)

            # If shard_idx=0, set layer ignore to match shard.
            if should_ignore_layer is None:
                should_ignore_layer = should_ignore_shard

            # If shard_idx=1+ confirm scheme matches prior shards.
            elif should_ignore_shard != should_ignore_layer:
                raise ValueError(f"Found a different quantization schemes for "
                                 f"{shard_proj_names} in {layer_name}. vLLM "
                                 "requires all to use the same scheme.")

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    else:
        should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
                                                         targets=ignore)

    assert should_ignore_layer is not None
    return should_ignore_layer


def check_equal_or_regex_match(layer_name: str,
                               targets: Iterable[str]) -> bool:
    """
    Checks whether a layer_name is exactly equal or a regex match for
    if target starts with 're:' to any target in list.
    """
    for target in targets:
        if _is_equal_or_regex_match(layer_name, target):
            return True
    return False


def _is_equal_or_regex_match(value: str,
                             target: str,
                             check_contains: bool = False) -> bool:
    """
    Checks whether a value is exactly equal or a regex match for target
    if target starts with 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.
    """

    if target.startswith("re:"):
        pattern = target[3:]
        if re.match(pattern, value):
            return True
    elif check_contains:
        if target.lower() in value.lower():
            return True
    elif target == value:
        return True
    return False
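
A quick editor's sketch of how the exclusion helpers above behave (the layer names are made-up examples):

# Behavior sketch for the helpers above; names are illustrative.
ignore = ["lm_head", "re:.*mlp\\.gate$"]

should_ignore_layer("lm_head", ignore=ignore)                  # True, exact match
should_ignore_layer("model.layers.0.mlp.gate", ignore=ignore)  # True, "re:" target
should_ignore_layer("model.layers.0.self_attn.qkv_proj", ignore=ignore)
# False: qkv_proj is expanded to its q_proj/k_proj/v_proj shards, none of
# which appear in `ignore`, and all shards must agree.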
@@ -13,8 +13,6 @@ from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.sampler import (SamplerOutput,
                                                 SamplingMetadata, get_sampler)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -390,12 +388,15 @@ class AriaMoELMModel(LlamaModel):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
+            if (self.quant_config is not None and
+                    (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                loaded_weight = loaded_weight[0]
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
@@ -437,6 +437,20 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+
+            if (self.quant_config is not None and
+                    (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -83,7 +83,7 @@ class DbrxExperts(FusedMoE):

     # Define custom weight loader for dbrx model
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str):
+                      weight_name: str, param_name: str):
         tp_rank = get_tensor_model_parallel_rank()
         param_data = param.data
         shard_size = self.intermediate_size
@@ -91,25 +91,37 @@ class DbrxExperts(FusedMoE):
         # DBRX uses GLU for each experts.
         # GLU has 3 linear layers: w1, v1 and w2.
         if weight_name.endswith("w1"):
-            loaded_weight = torch.reshape(
-                loaded_weight,
-                [-1, self.intermediate_size * self.tp_size, self.d_model],
-            )
-            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
+            if param_name.endswith("weight"):
+                loaded_weight = torch.reshape(
+                    loaded_weight,
+                    [-1, self.intermediate_size * self.tp_size, self.d_model],
+                )
+                param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
+            elif param_name.endswith("weight_scale"):
+                param_data[:, 0] = loaded_weight
+            else:
+                param_data = loaded_weight
         if weight_name.endswith("v1"):
-            loaded_weight = torch.reshape(
-                loaded_weight,
-                [-1, self.intermediate_size * self.tp_size, self.d_model],
-            )
-            param_data[:,
-                       shard_size:2 * shard_size, :] = loaded_weight[:,
-                                                                     shard, :]
+            if param_name.endswith("weight"):
+                loaded_weight = torch.reshape(
+                    loaded_weight,
+                    [-1, self.intermediate_size * self.tp_size, self.d_model],
+                )
+                param_data[:, shard_size:2 *
+                           shard_size, :] = loaded_weight[:, shard, :]
+            elif param_name.endswith("weight_scale"):
+                param_data[:, 1] = loaded_weight
+            else:
+                param_data[:] = loaded_weight
         if weight_name.endswith("w2"):
-            loaded_weight = torch.reshape(
-                loaded_weight,
-                [-1, self.intermediate_size * self.tp_size, self.d_model],
-            ).transpose(1, 2)
-            param_data[:] = loaded_weight[:, :, shard]
+            if param_name.endswith("weight"):
+                loaded_weight = torch.reshape(
+                    loaded_weight,
+                    [-1, self.intermediate_size * self.tp_size, self.d_model],
+                ).transpose(1, 2)
+                param_data[:] = loaded_weight[:, :, shard]
+            else:
+                param_data[:] = loaded_weight


 class DbrxMoE(nn.Module):
@@ -430,14 +442,29 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
 
         expert_params_mapping = [(
-            "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
+            "w13" if weight_name in ["w1", "v1"] else "w2",
             f"mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            if name.endswith(("w1", "w2", "v1")):
+                name = name + "_weight"
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
@@ -446,8 +473,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, weight_name)
+                weight_loader(param, loaded_weight, weight_name, name)
                 break
+
             else:
                 # Remapping the name of FP8 kv-scale.
                 name = maybe_remap_kv_scale_name(name, params_dict)
@@ -456,6 +484,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
 
                 if is_pp_missing_parameter(name, self):
                     continue
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
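Because an fp8 expert checkpoint carries both the expert weight and its weight_scale under the same mlp.w1 / mlp.v1 / mlp.w2 prefixes, the DBRX loader now receives the full parameter name as well. A small, self-contained sketch of the renaming and lookup idea (illustrative only; the helper and the example tensor names are hypothetical):

expert_params_mapping = [
    ("w13" if weight_name in ["w1", "v1"] else "w2", f"mlp.{weight_name}")
    for weight_name in ["w1", "v1", "w2"]
]

def classify(name: str):
    # Bare expert tensors ("...mlp.w1") become "...mlp.w1_weight" so that
    # weights and their "..._weight_scale" companions share one code path.
    if name.endswith(("w1", "w2", "v1")):
        name = name + "_weight"
    for param_name, weight_name in expert_params_mapping:
        if weight_name in name:
            # the real code then calls
            # weight_loader(param, loaded_weight, weight_name, name)
            return param_name, weight_name, name
    return None, None, name

print(classify("transformer.blocks.0.ffn.experts.mlp.w1"))
print(classify("transformer.blocks.0.ffn.experts.mlp.w1_weight_scale"))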
@@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -439,6 +437,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = quant_config
 
         self.transformer = ExaoneModel(
             vllm_config=vllm_config,
@@ -532,12 +531,15 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             # processed with quantization, LoRA, fine-tuning, etc.
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
                 continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                loaded_weight = loaded_weight[0]
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
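The rewritten expression tolerates both layouts a checkpoint may use for a kv-cache scale: presumably a 0-dim scalar (as Quark exports it) or a one-element vector (what the old loaded_weight[0] assumed for compressed-tensors). A quick standalone check:

import torch

# Both layouts collapse to a 0-dim scalar before being handed to the loader.
for t in (torch.tensor(0.02), torch.tensor([0.02])):
    scale = t if t.dim() == 0 else t[0]
    assert scale.dim() == 0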
@@ -31,8 +31,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -254,6 +252,7 @@ class Gemma2Model(nn.Module):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.quant_config = quant_config
 
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -329,7 +328,8 @@ class Gemma2Model(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            if scale_name := get_compressed_tensors_cache_scale(name):
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
                 # Loading kv cache scales for compressed-tensors quantization
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
@@ -313,6 +313,20 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
         for name, loaded_weight in weights:
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
+
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -371,6 +369,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = quant_config
 
         self.model = GraniteModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
@@ -474,12 +473,15 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             # processed with quantization, LoRA, fine-tuning, etc.
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
                 continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                loaded_weight = loaded_weight[0]
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
@@ -38,8 +38,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -306,6 +304,7 @@ class LlamaModel(nn.Module):
         lora_config = vllm_config.lora_config
 
         self.config = config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
@@ -396,12 +395,15 @@ class LlamaModel(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                loaded_weight = loaded_weight[0]
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
@@ -347,6 +347,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = quant_config
 
         self.model = MixtralModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
@@ -428,6 +429,19 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -1116,6 +1116,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         super().__init__()
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -1429,6 +1430,18 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
                 name = name.replace('patch_embedding.weight',
                                     'patch_embedding._linear.weight')
                 loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                updated_params.add(scale_name)
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -405,6 +405,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
 
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = quant_config
 
         self.model = NemotronModel(vllm_config=vllm_config,
                                    prefix=maybe_prefix(prefix, "model"))
@@ -489,6 +490,18 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -546,6 +546,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = vllm_config.quant_config
 
         self.model = PhiMoEModel(vllm_config=vllm_config,
                                  prefix=maybe_prefix(prefix, "model"))
@@ -623,6 +624,19 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
             if "rotary_emb.inv_freq" in name:
                 continue
 
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -279,6 +279,7 @@ class Qwen2Model(nn.Module):
         ))
 
         self.config = config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
@@ -364,6 +365,18 @@ class Qwen2Model(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -39,8 +39,6 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    get_compressed_tensors_cache_scale)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -409,6 +407,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         lora_config = vllm_config.lora_config
         self.config = config
         self.lora_config = lora_config
+        self.quant_config = quant_config
 
         self.model = SolarModel(
             vllm_config=vllm_config,
@@ -491,12 +490,15 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            if scale_name := get_compressed_tensors_cache_scale(name):
-                # Loading kv cache scales for compressed-tensors quantization
+            if (self.quant_config is not None and
+                (scale_name := self.quant_config.get_cache_scale(name))):
+                # Loading kv cache scales for quark and
+                # compressed-tensors quantization
                 param = params_dict[scale_name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
-                loaded_weight = loaded_weight[0]
+                loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
+                                 loaded_weight[0])
                 weight_loader(param, loaded_weight)
                 loaded_params.add(scale_name)
                 continue
@@ -56,8 +56,14 @@ class BasevLLMParameter(Parameter):
     def weight_loader(self):
         return self._weight_loader
 
+    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
+        cond1 = self.data.ndim == 1 and self.data.numel() == 1
+        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
+        return (cond1 and cond2)
+
     def _assert_and_load(self, loaded_weight: torch.Tensor):
-        assert self.data.shape == loaded_weight.shape
+        assert (self.data.shape == loaded_weight.shape
+                or self._is_1d_and_scalar(loaded_weight))
         self.data.copy_(loaded_weight)
 
     def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
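The relaxed assert lets a 0-dim checkpoint scalar be copied into a parameter that was registered with shape (1,). A standalone check of the two conditions, with plain tensors standing in for the vLLM parameter:

import torch

param_data = torch.zeros(1)        # ndim == 1, numel == 1
loaded_weight = torch.tensor(0.5)  # ndim == 0, numel == 1

cond1 = param_data.ndim == 1 and param_data.numel() == 1
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
assert cond1 and cond2

# copy_ broadcasts the 0-dim value into the one-element parameter.
param_data.copy_(loaded_weight)
assert param_data.item() == 0.5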
@@ -70,7 +70,7 @@ class RocmPlatform(Platform):
 
     supported_quantization: list[str] = [
         "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
-        "fbgemm_fp8", "gguf"
+        "fbgemm_fp8", "gguf", "quark"
    ]
 
     @classmethod
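With "quark" registered as a quantization method and allowed on ROCm, a Quark checkpoint loads through the usual entry points. A hedged usage sketch (the model path is a placeholder; passing quantization explicitly should be optional when the checkpoint config already declares quark):

from vllm import LLM, SamplingParams

# Placeholder path: substitute any Quark-quantized checkpoint.
llm = LLM(model="/path/to/quark-quantized-model", quantization="quark")
params = SamplingParams(temperature=0.0, max_tokens=20)
print(llm.generate(["Hello my name is"], params)[0].outputs[0].text)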