parent 24f1c01e0f
commit 652907b354
@@ -18,4 +18,5 @@ int8
 fp8
 quark
 quantized_kvcache
+torchao
 :::
docs/source/features/quantization/torchao.md (new file, 34 lines)
@@ -0,0 +1,34 @@
# TorchAO

TorchAO is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and it composes with native PyTorch features such as `torch.compile` and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).

We recommend installing the latest torchao nightly with:

```console
# Install the latest TorchAO nightly build
# Choose the CUDA version that matches your system (cu126, cu128, etc.)
pip install --pre "torchao>=0.10.0" --index-url https://download.pytorch.org/whl/nightly/cu126
```

## Quantizing HuggingFace Models

You can quantize your own Hugging Face model with torchao, e.g. with [transformers](https://huggingface.co/docs/transformers/main/en/quantization/torchao) or [diffusers](https://huggingface.co/docs/diffusers/en/quantization/torchao), and save the checkpoint to the Hugging Face Hub (like [this one](https://huggingface.co/jerryzh168/llama3-8b-int8wo)) with the following example code:

```python
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8WeightOnlyConfig

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

hub_repo = "YOUR HUB REPO ID"  # replace with your Hugging Face Hub repo id
tokenizer.push_to_hub(hub_repo)
quantized_model.push_to_hub(hub_repo, safe_serialization=False)
```

Alternatively, you can quantize models with a simple UI in the [TorchAO Quantization space](https://huggingface.co/spaces/medmekk/TorchAO_Quantization).
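Once the quantized checkpoint is on the Hub, vLLM can load it directly by passing `quantization="torchao"`. A minimal sketch, using the example checkpoint linked above (substitute your own repo id):

```python
from vllm import LLM, SamplingParams

# Example checkpoint from above; substitute the repo you pushed to.
llm = LLM(model="jerryzh168/llama3-8b-int8wo",
          quantization="torchao",
          dtype="bfloat16")
outputs = llm.generate(["What are we having for dinner?"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```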
tests/quantization/test_torchao.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
import importlib.util

import pytest

DTYPE = ["bfloat16"]

TORCHAO_AVAILABLE = importlib.util.find_spec("torchao") is not None


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_pre_quantized_model(vllm_runner):
    # Load a checkpoint that was already quantized with torchao
    # (float8 dynamic activation / float8 weight) and run greedy decoding.
    with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
                     quantization="torchao",
                     dtype="bfloat16",
                     enforce_eager=True) as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

        assert output
        print(output)


if __name__ == "__main__":
    pytest.main([__file__])
@@ -31,7 +31,8 @@ QUANTIZATION_METHODS: List[str] = [
     "neuron_quant",
     "ipex",
     "quark",
-    "moe_wna16"
+    "moe_wna16",
+    "torchao",
 ]

 # The customized quantization methods which will be added to this dict.
@@ -103,6 +104,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     from .neuron_quant import NeuronQuantConfig
     from .ptpc_fp8 import PTPCFp8Config
     from .qqq import QQQConfig
+    from .torchao import TorchAOConfig
     from .tpu_int8 import Int8TpuConfig

     method_to_config: Dict[str, Type[QuantizationConfig]] = {
@@ -132,6 +134,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "ipex": IPEXConfig,
         "quark": QuarkConfig,
         "moe_wna16": MoeWNA16Config,
+        "torchao": TorchAOConfig,
     }
     # Update the `method_to_config` with customized quantization methods.
     method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
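With the list entry, the import, and the `method_to_config` mapping above in place, the `"torchao"` string resolves to the new config class. A minimal sketch of that lookup, for illustration:

```python
from vllm.model_executor.layers.quantization import get_quantization_config

# Resolves through the method_to_config mapping extended in this commit.
config_cls = get_quantization_config("torchao")
assert config_cls.__name__ == "TorchAOConfig"
```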
vllm/model_executor/layers/quantization/torchao.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs


class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""

    def __init__(self, torchao_config) -> None:
        self.torchao_config = torchao_config

    def __repr__(self) -> str:
        return f"TorchAOConfig({self.torchao_config})"

    def get_name(self) -> str:
        return "torchao"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.float32, torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "TorchAOConfig":
        """Create the quant config from an hf model config"""
        try:
            from torchao.core.config import config_from_dict
        except ImportError as err:
            raise ImportError(
                "Please install torchao>=0.10.0 via "
                "`pip install torchao>=0.10.0` to use torchao quantization."
            ) from err

        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["TorchAOLinearMethod"]:
        if isinstance(layer, LinearBase):
            return TorchAOLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


def torchao_quantize_param_data(param: torch.Tensor,
                                torchao_config: Any) -> torch.nn.Parameter:
    """Quantize a Tensor with torchao quantization specified by torchao_config

    Args:
        param: weight parameter of the linear module
        torchao_config: type of quantization and their arguments we want
            to use to quantize the Tensor
    """
    from torchao.core.config import AOBaseConfig
    from torchao.quantization import quantize_
    assert isinstance(torchao_config, AOBaseConfig)
    # quantize_ operates on modules, so wrap the weight in a throwaway
    # Linear module and let torchao swap the parameter in place.
    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy_linear.weight = param
    quantize_(dummy_linear, torchao_config)
    return dummy_linear.weight


class TorchAOLinearMethod(LinearMethodBase):
    """Linear method for torchao.

    Args:
        torchao_config: The torchao quantization config, an `AOBaseConfig`
            object that encodes the type of quantization and all relevant
            arguments.
    """

    def __init__(self, quant_config: TorchAOConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # Allocate the unquantized weight, then quantize it in place with
        # the torchao config before registering it on the layer.
        weight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        weight = torchao_quantize_param_data(weight,
                                             self.quant_config.torchao_config)

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

        layer.register_parameter("weight", weight)
        set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # torchao's quantized tensor subclasses hook into F.linear dispatch.
        return F.linear(x, layer.weight, bias)
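For context on what `from_config` consumes: a checkpoint's `config.json` stores the serialized torchao config under `quantization_config.quant_type["default"]`, and `config_from_dict` rebuilds the `AOBaseConfig` object from it. A sketch of that round-trip, assuming torchao>=0.10.0 also exposes the matching `config_to_dict` serializer:

```python
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Int8WeightOnlyConfig

# Serialize a quantization config to a plain dict (what ends up in
# config.json)...
serialized = config_to_dict(Int8WeightOnlyConfig())
# ...and restore it the way TorchAOConfig.from_config does.
restored = config_from_dict(serialized)
assert isinstance(restored, Int8WeightOnlyConfig)
```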