[Bugfix] Get available quantization methods from quantization registry (#4098)
This commit is contained in:
parent
66ded03067
commit
53b018edcb
@ -9,6 +9,7 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
|
||||
def main(args: argparse.Namespace):
|
||||
@ -101,7 +102,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument('--tokenizer', type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
choices=[*QUANTIZATION_METHODS, None],
|
||||
default=None)
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--input-len', type=int, default=32)
|
||||
|
@ -10,6 +10,8 @@ from tqdm import tqdm
|
||||
from transformers import (AutoModelForCausalLM, AutoTokenizer,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
|
||||
def sample_requests(
|
||||
dataset_path: str,
|
||||
@ -267,7 +269,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--tokenizer", type=str, default=None)
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
choices=[*QUANTIZATION_METHODS, None],
|
||||
default=None)
|
||||
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
|
||||
parser.add_argument("--n",
|
||||
|
@ -16,13 +16,12 @@ from dataclasses import dataclass
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
_QUANTIZATION_CONFIG_REGISTRY)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
marlin_not_supported = (
|
||||
capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())
|
||||
marlin_not_supported = (capability <
|
||||
QUANTIZATION_METHODS["marlin"].get_min_capability())
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -9,6 +9,7 @@ from packaging.version import Version
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.transformers_utils.config import get_config, get_hf_text_config
|
||||
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
|
||||
is_neuron)
|
||||
@ -118,8 +119,8 @@ class ModelConfig:
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
|
||||
rocm_not_supported_quantization = ["awq", "marlin"]
|
||||
supported_quantization = [*QUANTIZATION_METHODS]
|
||||
rocm_supported_quantization = ["gptq", "squeezellm"]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
@ -155,7 +156,7 @@ class ModelConfig:
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
if is_hip(
|
||||
) and self.quantization in rocm_not_supported_quantization:
|
||||
) and self.quantization not in rocm_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not "
|
||||
f"supported in ROCm.")
|
||||
|
@ -7,6 +7,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, SpeculativeConfig,
|
||||
TokenizerPoolConfig, VisionLanguageConfig)
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import str_to_int_tuple
|
||||
|
||||
|
||||
@ -286,7 +287,7 @@ class EngineArgs:
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=str,
|
||||
choices=['awq', 'gptq', 'squeezellm', None],
|
||||
choices=[*QUANTIZATION_METHODS, None],
|
||||
default=EngineArgs.quantization,
|
||||
help='Method used to quantize the weights. If '
|
||||
'None, we first check the `quantization_config` '
|
||||
|
@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
|
||||
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
|
||||
|
||||
_QUANTIZATION_CONFIG_REGISTRY = {
|
||||
QUANTIZATION_METHODS = {
|
||||
"awq": AWQConfig,
|
||||
"gptq": GPTQConfig,
|
||||
"squeezellm": SqueezeLLMConfig,
|
||||
@ -16,12 +16,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
|
||||
|
||||
|
||||
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
||||
if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
|
||||
if quantization not in QUANTIZATION_METHODS:
|
||||
raise ValueError(f"Invalid quantization method: {quantization}")
|
||||
return _QUANTIZATION_CONFIG_REGISTRY[quantization]
|
||||
return QUANTIZATION_METHODS[quantization]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"QuantizationConfig",
|
||||
"get_quantization_config",
|
||||
"QUANTIZATION_METHODS",
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user