[Misc] Add compressed-tensors to optimized quant list (#7006)

This commit is contained in:
Michael Goin 2024-07-31 17:40:44 -04:00 committed by GitHub
parent 35e9c12bfa
commit a0dce9383a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -197,13 +197,17 @@ class ModelConfig:
def _parse_quant_hf_config(self): def _parse_quant_hf_config(self):
quant_cfg = getattr(self.hf_config, "quantization_config", None) quant_cfg = getattr(self.hf_config, "quantization_config", None)
if quant_cfg is None: if quant_cfg is None:
# compress-tensors uses a "compression_config" key # compressed-tensors uses a "compression_config" key
quant_cfg = getattr(self.hf_config, "compression_config", None) quant_cfg = getattr(self.hf_config, "compression_config", None)
return quant_cfg return quant_cfg
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"] rocm_supported_quantization = ["gptq", "squeezellm"]
optimized_quantization_methods = [
"fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
"fbgemm_fp8", "compressed_tensors", "compressed-tensors"
]
if self.quantization is not None: if self.quantization is not None:
self.quantization = self.quantization.lower() self.quantization = self.quantization.lower()
@ -242,9 +246,7 @@ class ModelConfig:
raise ValueError( raise ValueError(
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
f"supported in ROCm.") f"supported in ROCm.")
if (self.quantization if self.quantization not in optimized_quantization_methods:
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors")):
logger.warning( logger.warning(
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "