Signed-off-by: drisspg <drisspguessous@gmail.com>
Driss Guessous 2025-04-07 16:39:28 -07:00 committed by GitHub
parent 24f1c01e0f
commit 652907b354
5 changed files with 191 additions and 1 deletion

View File

@ -18,4 +18,5 @@ int8
fp8
quark
quantized_kvcache
torchao
:::

View File

@ -0,0 +1,34 @@
# TorchAO
TorchAO is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features like torch.compile and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).

We recommend installing the latest torchao nightly with:
```console
# Install the latest TorchAO nightly build
# Choose the CUDA version that matches your system (cu126, cu128, etc.)
pip install --pre "torchao>=0.10.0" --index-url https://download.pytorch.org/whl/nightly/cu126
```
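To confirm the install picked up a recent build, you can check the version (a quick sanity check, not part of the original instructions):

```console
python -c "import torchao; print(torchao.__version__)"
```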
## Quantizing HuggingFace Models
You can quantize your own Hugging Face model with torchao, e.g. with [transformers](https://huggingface.co/docs/transformers/main/en/quantization/torchao) and [diffusers](https://huggingface.co/docs/diffusers/en/quantization/torchao), and save the checkpoint to the Hugging Face Hub like [this](https://huggingface.co/jerryzh168/llama3-8b-int8wo) with the following example code:
```python
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8WeightOnlyConfig

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Sanity-check the quantized model with a short generation
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
output_ids = quantized_model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

# Push the quantized checkpoint and tokenizer to your own hub repo
hub_repo = "YOUR_HUB_REPO_ID"  # e.g. "<username>/llama3-8b-int8wo"
tokenizer.push_to_hub(hub_repo)
quantized_model.push_to_hub(hub_repo, safe_serialization=False)
```
Alternatively, you can use the [TorchAO Quantization space](https://huggingface.co/spaces/medmekk/TorchAO_Quantization) for quantizing models with a simple UI.
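Once the quantized checkpoint is on the hub, it can be served with the `torchao` quantization method added in this change. A minimal sketch, using the int8 weight-only checkpoint linked above:

```console
vllm serve jerryzh168/llama3-8b-int8wo --quantization torchao --dtype bfloat16
```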

View File

@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util

import pytest

DTYPE = ["bfloat16"]

TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.parametrize("dtype", DTYPE)
def test_pre_quantized_model(vllm_runner, dtype):
    with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
                     quantization="torchao",
                     dtype=dtype,
                     enforce_eager=True) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

        assert output
        print(output)


if __name__ == "__main__":
    pytest.main([__file__])
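To run this test locally (assuming the file lands at `tests/quantization/test_torchao.py`; the diff view does not show the path):

```console
pytest tests/quantization/test_torchao.py -v
```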

View File

@ -31,7 +31,8 @@ QUANTIZATION_METHODS: List[str] = [
    "neuron_quant",
    "ipex",
    "quark",
    "moe_wna16",
    "torchao",
]

# The customized quantization methods which will be added to this dict.
@ -103,6 +104,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    from .neuron_quant import NeuronQuantConfig
    from .ptpc_fp8 import PTPCFp8Config
    from .qqq import QQQConfig
    from .torchao import TorchAOConfig
    from .tpu_int8 import Int8TpuConfig

    method_to_config: Dict[str, Type[QuantizationConfig]] = {
@ -132,6 +134,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
        "ipex": IPEXConfig,
        "quark": QuarkConfig,
        "moe_wna16": MoeWNA16Config,
        "torchao": TorchAOConfig,
    }

    # Update the `method_to_config` with customized quantization methods.
    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
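With the registration above, the new method is reachable through the standard lookup path. A minimal sketch, assuming a vLLM checkout with this change applied:

```python
from vllm.model_executor.layers.quantization import get_quantization_config

# Resolves the "torchao" method name to the TorchAOConfig class
config_cls = get_quantization_config("torchao")
print(config_cls.__name__)  # TorchAOConfig
```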

View File

@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs


class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""

    def __init__(self, torchao_config) -> None:
        self.torchao_config = torchao_config

    def __repr__(self) -> str:
        return f"TorchAOConfig({self.torchao_config})"

    def get_name(self) -> str:
        return "torchao"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
        """Create the quant config from an hf model config"""
        try:
            from torchao.core.config import config_from_dict
        except ImportError as err:
            raise ImportError(
                "Please install torchao>=0.10.0 via "
                "`pip install torchao>=0.10.0` to use torchao quantization."
            ) from err

        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["TorchAOLinearMethod"]:
        if isinstance(layer, LinearBase):
            return TorchAOLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


def torchao_quantize_param_data(param: torch.Tensor,
                                torchao_config: Any) -> torch.nn.Parameter:
    """Quantize a Tensor with torchao quantization specified by torchao_config

    Args:
        param: weight parameter of the linear module
        torchao_config: type of quantization and their arguments we want to
            use to quantize the Tensor
    """
    from torchao.core.config import AOBaseConfig
    from torchao.quantization import quantize_
    assert isinstance(torchao_config, AOBaseConfig)

    # Wrap the weight in a throwaway nn.Linear so torchao's quantize_ can
    # swap it in place for a quantized tensor subclass, then return it.
    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param
    quantize_(dummy_linear, torchao_config)
    return dummy_linear.weight


class TorchAOLinearMethod(LinearMethodBase):
    """Linear method for torchao.

    Args:
        quant_config: The torchao quantization config, which encodes the
            type of quantization and all relevant arguments.
    """

    def __init__(self, quant_config: TorchAOConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocate the full-precision weight, then replace it with its
        # quantized counterpart before registering it on the layer.
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        weight = torchao_quantize_param_data(weight,
                                             self.quant_config.torchao_config)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return F.linear(x, layer.weight, bias)
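To illustrate what `torchao_quantize_param_data` does, here is a small standalone sketch (not part of the diff), assuming torchao>=0.10.0 is installed and the helper above is in scope; `Int8WeightOnlyConfig` stands in for whatever config a checkpoint would specify:

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig

# A (out_features=256, in_features=128) weight, as vLLM would allocate it
weight = torch.nn.Parameter(torch.randn(256, 128, dtype=torch.bfloat16),
                            requires_grad=False)

# The returned parameter's data is a torchao quantized tensor subclass;
# F.linear on it dispatches to torchao's int8 weight-only kernels.
quantized = torchao_quantize_param_data(weight, Int8WeightOnlyConfig())
print(type(quantized.data))
```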