vllm/tests/quantization/test_register_quantization_config.py
Russell Bryant e489ad7a21
[Misc] Add SPDX-License-Identifier headers to python source files (#12628)
- **Add SPDX license headers to python source files**
- **Check for SPDX headers using pre-commit**

commit 9d7ef44c3cfb72ca4c32e1c677d99259d10d4745
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:18:24 2025 -0500

    Add SPDX license headers to python source files

    This commit adds SPDX license headers to python source files, as
    recommended to the project by the Linux Foundation. These headers
    provide a concise, human- and machine-readable way to communicate the
    license information for each source file. They avoid any ambiguity
    about the license of the code and can also be used by tools to help
    manage license compliance.

    The Linux Foundation runs license scans against the codebase to help
    ensure we are in compliance with the licenses of the code we use,
    including dependencies. Having these headers in place helps that tool
    do its job.

    More information can be found on the SPDX site:

    - https://spdx.dev/learn/handling-license-info/

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 5a1cf1cb3b80759131c73f6a9dddebccac039dea
Author: Russell Bryant <rbryant@redhat.com>
Date:   Fri Jan 31 14:36:32 2025 -0500

    Check for SPDX headers using pre-commit

    Signed-off-by: Russell Bryant <rbryant@redhat.com>

---------

Signed-off-by: Russell Bryant <rbryant@redhat.com>
2025-02-02 11:58:18 -08:00
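As background for the "Check for SPDX headers using pre-commit" commit above,
here is a minimal sketch of the kind of checker script such a hook could
invoke. This is a hypothetical illustration only: the script name, header
text, and exact behavior of vLLM's real hook are assumptions, not taken from
this page.

# check_spdx_header.py -- hypothetical sketch of a pre-commit SPDX checker.
# pre-commit passes the staged file names as arguments; a non-zero exit
# status makes the hook fail.
import sys

SPDX_HEADER = "# SPDX-License-Identifier: Apache-2.0"

def main(paths):
    missing = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            # Allow the header to follow an optional shebang/encoding line.
            head = [f.readline() for _ in range(3)]
        if not any(line.startswith(SPDX_HEADER) for line in head):
            missing.append(path)
    for path in missing:
        print(f"missing SPDX header: {path}")
    return 1 if missing else 0

if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))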


# SPDX-License-Identifier: Apache-2.0
"""Tests registering a custom quantization config.

See https://github.com/vllm-project/vllm/issues/11926 for more details.

Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
from typing import Any, Dict, List, Optional

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
    get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)


class FakeQuantLinearMethod(UnquantizedLinearMethod):
    """Fake quantization linear method for per-token dynamic quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization method."""
        super().__init__()
        self.num_bits = num_bits

    def apply(self,
              layer: "torch.nn.Module",
              x: "torch.Tensor",
              bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
        """Perform fake quantization before the linear layer."""
        # Calculate the scales dynamically
        max_val = torch.amax(x, dim=(0, -1), keepdims=True)
        min_val = torch.amin(x, dim=(0, -1), keepdims=True)
        scales = (max_val - min_val) / (2**self.num_bits - 1)
        # Fake quantize the input
        quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1),
                              2**(self.num_bits - 1) - 1)
        dequant_x = quant_x * scales
        return F.linear(dequant_x, layer.weight, bias)
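
# A worked example of the fake-quant math above (illustrative numbers only,
# not used by the test): with num_bits=8 there are 2**8 - 1 = 255 steps, so
# activations spanning [-1.0, 1.0] give scales = 2.0 / 255 ≈ 0.00784. Each
# value is rounded to the nearest multiple of that scale, clamped to the
# signed integer range [-128, 127], then dequantized before F.linear.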


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Custom quantization config for per-token dynamic fake quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization config."""
        self.num_bits = num_bits

    def get_name(self) -> str:
        """Name of the quantization method."""
        return "custom_quant"

    def get_supported_act_dtypes(self) -> List["torch.dtype"]:
        """List of supported activation dtypes."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method."""
        return -1

    @staticmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
        """Create a config class from the model's quantization config."""
        return CustomQuantConfig(num_bits=config.get("num_bits", 8))

    def get_quant_method(self, layer: "torch.nn.Module",
                         prefix: str) -> Optional["FakeQuantLinearMethod"]:
        """Get the quantize method to use for the quantized layer."""
        if isinstance(layer, LinearBase):
            return FakeQuantLinearMethod(num_bits=self.num_bits)
        return None
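
# For reference, `CustomQuantConfig.from_config` above would be fed the
# `quantization_config` block of a model's config.json. A hypothetical
# example of such a block (key names assumed, not taken from this file):
#
#   {"quantization_config": {"quant_method": "custom_quant", "num_bits": 8}}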


def test_register_quantization_config():
    """Test registering a custom quantization config."""

    # The quantization method `custom_quant` should be registered.
    assert get_quantization_config("custom_quant") == CustomQuantConfig

    # The quantization method `custom_quant` already exists, so registering
    # it again should raise an error.
    with pytest.raises(ValueError):
        register_quantization_config("custom_quant")(CustomQuantConfig)


@pytest.mark.parametrize(argnames="model",
                         argvalues=[
                             "meta-llama/Meta-Llama-3-8B-Instruct",
                         ])
def test_custom_quant(vllm_runner, model):
    """Test inference with the custom quantization method."""
    with vllm_runner(model_name=model,
                     quantization="custom_quant",
                     enforce_eager=True) as llm:

        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]
        qkv_proj = layer.self_attn.qkv_proj

        # Check that the quantization method is FakeQuantLinearMethod.
        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
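
# For reference, a minimal out-of-test usage sketch (assumes GPU access and
# locally available model weights; not executed by pytest). This module must
# be imported first so that `custom_quant` is registered:
#
#   from vllm import LLM
#   llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
#             quantization="custom_quant",
#             enforce_eager=True)
#   print(llm.generate("Hello my name is"))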