[Bugfix] Get available quantization methods from quantization registry (#4098)

Author: Michael Goin, 2024-04-18 03:21:55 -04:00 (committed by GitHub)
parent 66ded03067
commit 53b018edcb
6 changed files with 18 additions and 13 deletions

benchmarks/benchmark_latency.py

@@ -9,6 +9,7 @@ import torch
 from tqdm import tqdm
 
 from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
 
 def main(args: argparse.Namespace):
@@ -101,7 +102,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
+                        choices=[*QUANTIZATION_METHODS, None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
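
The `choices=[*QUANTIZATION_METHODS, None]` pattern works because iterating a dict yields its keys, so the argparse choices list stays in sync with the registry automatically; the same pattern recurs in the throughput benchmark and in EngineArgs below. A minimal standalone sketch, with `...` as a stand-in for the real config classes:

import argparse

# Stand-in registry; the real values are quantization config classes.
QUANTIZATION_METHODS = {"awq": ..., "gptq": ..., "squeezellm": ..., "marlin": ...}

parser = argparse.ArgumentParser()
# Unpacking the dict yields its keys, so this expands to
# choices=['awq', 'gptq', 'squeezellm', 'marlin', None].
parser.add_argument('--quantization',
                    '-q',
                    choices=[*QUANTIZATION_METHODS, None],
                    default=None)

print(parser.parse_args(['-q', 'marlin']).quantization)  # -> marlin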

benchmarks/benchmark_throughput.py

@@ -10,6 +10,8 @@ from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
 
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
 
 def sample_requests(
     dataset_path: str,
@@ -267,7 +269,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
+                        choices=[*QUANTIZATION_METHODS, None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",

tests/models/test_marlin.py

@@ -16,13 +16,12 @@ from dataclasses import dataclass
 import pytest
 import torch
 
-from vllm.model_executor.layers.quantization import (
-    _QUANTIZATION_CONFIG_REGISTRY)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
 capability = torch.cuda.get_device_capability()
 capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (
-    capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())
+marlin_not_supported = (capability <
+                        QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass
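
For context on the gate above: `torch.cuda.get_device_capability()` returns a `(major, minor)` tuple, which the test flattens into a two-digit integer before comparing against the method's minimum. A sketch with example values, assuming Marlin's `get_min_capability()` returns 80 (sm_80, i.e. Ampere or newer):

# Example values only; on a real machine the tuple comes from
# torch.cuda.get_device_capability().
major, minor = 7, 5                # e.g. a Turing T4 -> sm_75
capability = major * 10 + minor    # (7, 5) -> 75

MARLIN_MIN_CAPABILITY = 80         # assumed return of get_min_capability()
marlin_not_supported = capability < MARLIN_MIN_CAPABILITY
print(marlin_not_supported)        # True: this device cannot run Marlin kernels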

vllm/config.py

@@ -9,6 +9,7 @@ from packaging.version import Version
 from transformers import PretrainedConfig
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
                         is_neuron)
@@ -118,8 +119,8 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
-        rocm_not_supported_quantization = ["awq", "marlin"]
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = ["gptq", "squeezellm"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
@@ -155,7 +156,7 @@
                     f"Unknown quantization method: {self.quantization}. Must "
                     f"be one of {supported_quantization}.")
             if is_hip(
-            ) and self.quantization in rocm_not_supported_quantization:
+            ) and self.quantization not in rocm_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")

vllm/engine/arg_utils.py

@@ -7,6 +7,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import str_to_int_tuple
@@ -286,7 +287,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'gptq', 'squeezellm', None],
+                            choices=[*QUANTIZATION_METHODS, None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
                             'None, we first check the `quantization_config` '

vllm/model_executor/layers/quantization/__init__.py

@@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 
-_QUANTIZATION_CONFIG_REGISTRY = {
+QUANTIZATION_METHODS = {
     "awq": AWQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
@@ -16,12 +16,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
-    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
+    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
-    return _QUANTIZATION_CONFIG_REGISTRY[quantization]
+    return QUANTIZATION_METHODS[quantization]
 
 
 __all__ = [
     "QuantizationConfig",
     "get_quantization_config",
+    "QUANTIZATION_METHODS",
 ]
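
The payoff of renaming and exporting the registry: a method registered here is picked up everywhere that iterates `QUANTIZATION_METHODS`, so the CLI choices, `ModelConfig` validation, and `get_quantization_config` can no longer drift apart. An illustrative, self-contained sketch (`FooConfig` and the "foo" entry are invented for the example):

from typing import Dict, Type

class QuantizationConfig:               # stand-in for the real base class
    pass

class FooConfig(QuantizationConfig):    # hypothetical new method
    pass

QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "foo": FooConfig,                   # one registration ...
}

# ... and every consumer that iterates the registry sees it:
cli_choices = [*QUANTIZATION_METHODS, None]        # argparse choices
supported_quantization = [*QUANTIZATION_METHODS]   # ModelConfig validation

def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return QUANTIZATION_METHODS[quantization]

print(get_quantization_config("foo"))  # -> <class '__main__.FooConfig'>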