diff --git a/vllm/config.py b/vllm/config.py
index fd48cc3a..de5d0402 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -197,13 +197,17 @@ class ModelConfig:
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is None:
-            # compress-tensors uses a "compression_config" key
+            # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
         return quant_cfg
 
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = ["gptq", "squeezellm"]
+        optimized_quantization_methods = [
+            "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
+            "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
+        ]
 
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -242,9 +246,7 @@ class ModelConfig:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
-            if (self.quantization
-                    not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
-                            "awq_marlin", "fbgemm_fp8", "compressed_tensors")):
+            if self.quantization not in optimized_quantization_methods:
                 logger.warning(
                     "%s quantization is not fully "
                     "optimized yet. The speed can be slower than "
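A minimal standalone sketch (not vllm code) of how the refactored check behaves after this change: the list mirrors the diff's optimized_quantization_methods, while the logger setup, the sample method names, and the abbreviated warning message are illustrative assumptions.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("vllm.config")

# Mirrors the list introduced in the diff; note it now carries both the
# underscore and hyphen spellings of compressed-tensors.
optimized_quantization_methods = [
    "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
    "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
]

# Hypothetical inputs: only "awq" should trigger the warning, since both
# compressed-tensors spellings are in the optimized list.
for quantization in ("awq", "gptq_marlin", "compressed-tensors"):
    if quantization not in optimized_quantization_methods:
        # Message shortened here; the diff's hunk truncates the full text.
        logger.warning("%s quantization is not fully optimized yet.",
                       quantization)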