[Bugfix] Get available quantization methods from quantization registry (#4098)

parent 66ded03067
commit 53b018edcb
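
Summary: the private _QUANTIZATION_CONFIG_REGISTRY dict in vllm.model_executor.layers.quantization is renamed to QUANTIZATION_METHODS and exported, and the hard-coded method lists (['awq', 'gptq', 'squeezellm', ...]) in the benchmark scripts' CLI parsers, in ModelConfig._verify_quantization, and in the EngineArgs parser are replaced with the registry, so the set of available methods is defined in one place. The ROCm check also flips from a deny-list (rocm_not_supported_quantization) to an allow-list (rocm_supported_quantization).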
@@ -9,6 +9,7 @@ import torch
 from tqdm import tqdm
 
 from vllm import LLM, SamplingParams
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
 
 def main(args: argparse.Namespace):
@@ -101,7 +102,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
+                        choices=[*QUANTIZATION_METHODS, None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)

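The choices change (repeated below in the second benchmark parser and in EngineArgs) relies on the fact that unpacking a dict in a list literal yields its keys, so choices=[*QUANTIZATION_METHODS, None] expands to the registered method names plus None. A minimal, runnable sketch, using an illustrative stand-in registry rather than the real config classes:

import argparse

# Illustrative stand-in for vLLM's registry; the real one maps method
# names to QuantizationConfig subclasses.
QUANTIZATION_METHODS = {"awq": object, "gptq": object, "squeezellm": object}

parser = argparse.ArgumentParser()
# Dict unpacking yields the keys, so newly registered methods appear
# in the CLI choices automatically.
parser.add_argument('--quantization',
                    '-q',
                    choices=[*QUANTIZATION_METHODS, None],
                    default=None)

args = parser.parse_args(['-q', 'awq'])
print(args.quantization)  # awq
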
@@ -10,6 +10,8 @@ from tqdm import tqdm
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           PreTrainedTokenizerBase)
 
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+
 
 def sample_requests(
     dataset_path: str,
@@ -267,7 +269,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq', 'squeezellm', None],
+                        choices=[*QUANTIZATION_METHODS, None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",

@@ -16,13 +16,12 @@ from dataclasses import dataclass
 import pytest
 import torch
 
-from vllm.model_executor.layers.quantization import (
-    _QUANTIZATION_CONFIG_REGISTRY)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 
 capability = torch.cuda.get_device_capability()
 capability = capability[0] * 10 + capability[1]
-marlin_not_supported = (
-    capability < _QUANTIZATION_CONFIG_REGISTRY["marlin"].get_min_capability())
+marlin_not_supported = (capability <
+                        QUANTIZATION_METHODS["marlin"].get_min_capability())
 
 
 @dataclass

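For context, the capability arithmetic in this test folds the (major, minor) CUDA compute capability into a single integer. A hedged sketch, assuming a CUDA device is present and using an illustrative threshold in place of the value MarlinConfig actually reports via get_min_capability():

import torch

# (major, minor) -> one integer, e.g. (8, 6) -> 86 on an RTX 3090.
major, minor = torch.cuda.get_device_capability()
capability = major * 10 + minor

MARLIN_MIN_CAPABILITY = 80  # illustrative; real code reads it from the registry
marlin_not_supported = capability < MARLIN_MIN_CAPABILITY
print(capability, marlin_not_supported)
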
@@ -9,6 +9,7 @@ from packaging.version import Version
 from transformers import PretrainedConfig
 
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.transformers_utils.config import get_config, get_hf_text_config
 from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
                         is_neuron)
@@ -118,8 +119,8 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
-        rocm_not_supported_quantization = ["awq", "marlin"]
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = ["gptq", "squeezellm"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
@@ -155,7 +156,7 @@ class ModelConfig:
                     f"Unknown quantization method: {self.quantization}. Must "
                     f"be one of {supported_quantization}.")
             if is_hip(
-            ) and self.quantization in rocm_not_supported_quantization:
+            ) and self.quantization not in rocm_supported_quantization:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")

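Flipping rocm_not_supported_quantization (a deny-list) to rocm_supported_quantization (an allow-list) changes the default for future methods: anything newly added to the registry is rejected on ROCm until it is explicitly allowed. A self-contained sketch of the resulting validation flow (names follow the diff; the registry contents are illustrative):

QUANTIZATION_METHODS = {"awq": None, "gptq": None, "squeezellm": None,
                        "marlin": None}  # illustrative stand-in

def verify_quantization(quantization, on_rocm):
    supported_quantization = [*QUANTIZATION_METHODS]
    rocm_supported_quantization = ["gptq", "squeezellm"]
    if quantization is None:
        return
    quantization = quantization.lower()
    if quantization not in supported_quantization:
        raise ValueError(f"Unknown quantization method: {quantization}. "
                         f"Must be one of {supported_quantization}.")
    # Allow-list: reject anything not explicitly supported on ROCm.
    if on_rocm and quantization not in rocm_supported_quantization:
        raise ValueError(f"{quantization} quantization is currently not "
                         f"supported in ROCm.")

try:
    verify_quantization("awq", on_rocm=True)
except ValueError as e:
    print(e)  # awq quantization is currently not supported in ROCm.
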
@@ -7,6 +7,7 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
                          EngineConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
                          TokenizerPoolConfig, VisionLanguageConfig)
+from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import str_to_int_tuple
 
 
@@ -286,7 +287,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'gptq', 'squeezellm', None],
+                            choices=[*QUANTIZATION_METHODS, None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
                             'None, we first check the `quantization_config` '

@@ -7,7 +7,7 @@ from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 
-_QUANTIZATION_CONFIG_REGISTRY = {
+QUANTIZATION_METHODS = {
     "awq": AWQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
@@ -16,12 +16,13 @@ _QUANTIZATION_CONFIG_REGISTRY = {
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
-    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
+    if quantization not in QUANTIZATION_METHODS:
         raise ValueError(f"Invalid quantization method: {quantization}")
-    return _QUANTIZATION_CONFIG_REGISTRY[quantization]
+    return QUANTIZATION_METHODS[quantization]
 
 
 __all__ = [
     "QuantizationConfig",
     "get_quantization_config",
+    "QUANTIZATION_METHODS",
 ]
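
With the rename and the new __all__ entry, downstream code can enumerate methods or fetch a config class through public names. A usage sketch, assuming a vLLM checkout at this commit is importable:

from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

# Iterating the registry yields the method names (dict keys).
print(list(QUANTIZATION_METHODS))  # e.g. ['awq', 'gptq', 'squeezellm', 'marlin']

# Lookups go through get_quantization_config; unknown names raise ValueError.
print(get_quantization_config("gptq").__name__)  # GPTQConfig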