[Kernel] w4a16 support for compressed-tensors (#5385)

Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>

Commit: c2637a613b (parent: 88407532e7)
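For context, the sketch below shows how a checkpoint quantized this way could be loaded end to end; it uses one of the test models referenced later in this diff and the standard vLLM offline API, and is illustrative only (not part of the commit).

# Illustrative only: exercise the new w4a16 compressed-tensors path with the
# offline API. The checkpoint name comes from the tests added below.
from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/tinyllama-oneshot-w4a16-group128-v2")
sampling_params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)
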
@@ -3,12 +3,13 @@
 Run `pytest tests/quantization/test_compressed_tensors.py`.
 """
 
+import pytest
 import torch
 
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
-    CompressedTensorsLinearMethod, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsLinearMethod, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
 
 
 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -60,3 +61,25 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
         assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
         assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8DynamicToken)
         assert qkv_proj.weight.dtype is torch.int8
+
+
+@pytest.mark.parametrize("w4a16_args", [
+    ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None),
+    ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128),
+])
+def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
+    model, strategy, group = w4a16_args
+    with vllm_runner(model) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16)
+
+        assert qkv_proj.scheme.strategy == strategy
+        assert qkv_proj.scheme.group_size == group
+
+        assert qkv_proj.weight_packed.dtype is torch.int32
+        assert qkv_proj.weight_scale.dtype is torch.float16
+        assert qkv_proj.weight_packed.pack_factor == 8
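A note on the last assertion above: the packed weights are stored as int32, so with 4-bit values each word holds 32 // 4 = 8 weights, which is the pack_factor the test expects. A small sketch of that bookkeeping with hypothetical layer sizes (not taken from the test):

# Hypothetical sizes; only the packing arithmetic used by the w4a16 scheme.
num_bits = 4
pack_factor = 32 // num_bits                 # 8 int4 values per int32 word
input_size_per_partition = 2048              # example K dimension
packed_cols = input_size_per_partition // pack_factor
print(pack_factor, packed_cols)              # 8 256
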
@@ -7,8 +7,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
-    CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken,
-    CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsScheme, CompressedTensorsW4A16,
+    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
 
@@ -47,16 +47,27 @@ class CompressedTensorsConfig(QuantizationConfig):
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
 
+        # The quant_config has multiple config_groups, each containing
+        # an input_activations key with details about how the activations are
+        # quantized, a weights key indicating how the weights are quantized,
+        # and a list of targets under the `targets` key, dictating which
+        # layers are impacted by the quantization details. The quantization
+        # details follow the structure defined by the QuantizationArgs
+        # pydantic model, which is used to verify the structure of the
+        # quant_config and also store the details for later use.
         for key, quant_config in config["config_groups"].items():
             targets = quant_config.get("targets")
             for target in targets:
                 layer_quant_details[target] = {}
                 layer_quant_details[target][
-                    "weight"] = QuantizationArgs.parse_obj(
+                    "weights"] = QuantizationArgs.parse_obj(
                         quant_config.get("weights"))
-                layer_quant_details[target][
-                    "input"] = QuantizationArgs.parse_obj(
-                        quant_config.get("input_activations"))
+                try:
+                    layer_quant_details[target][
+                        "input_activations"] = QuantizationArgs.parse_obj(
+                            quant_config.get("input_activations"))
+                except Exception:
+                    layer_quant_details[target]["input_activations"] = None
 
         return cls(layer_quant_details=layer_quant_details, ignore=ignore)
 
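To make the loop above concrete, a config group for a w4a16 checkpoint would be shaped roughly like the sketch below; only the keys read here (targets, weights, input_activations) plus the QuantizationArgs fields checked later in this file are shown, so treat it as an assumption rather than a verbatim checkpoint config.

# Rough shape of one entry in config["config_groups"] for weight-only int4
# (w4a16) quantization; illustrative, not copied from a real checkpoint.
example_config_group = {
    "targets": ["Linear"],
    "weights": {
        "num_bits": 4,
        "symmetric": True,
        "strategy": "group",
        "group_size": 128,
        "dynamic": False,
    },
    # No activation quantization: the try/except above stores None for
    # "input_activations" in layer_quant_details.
    "input_activations": None,
}
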
@@ -86,8 +97,23 @@ class CompressedTensorsConfig(QuantizationConfig):
 
         return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
 
+    def _is_w4a16(self, weight_quant: BaseModel,
+                  input_quant: BaseModel) -> bool:
+        input_quant_none = input_quant is None
+        is_4_bits = weight_quant.num_bits == 4
+        is_symmetric = weight_quant.symmetric
+        is_static = not weight_quant.dynamic
+
+        return is_4_bits and input_quant_none and is_symmetric and is_static
+
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
+
+        if self._is_w4a16(weight_quant, input_quant):
+            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
+                                          strategy=weight_quant.strategy,
+                                          group_size=weight_quant.group_size)
+
         if self._is_static_tensor_w8a8(weight_quant, input_quant):
             return CompressedTensorsW8A8StaticTensor()
 
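As a rough illustration of the new predicate, 4-bit symmetric static weight args with no input quantization satisfy _is_w4a16, while the 8-bit args used by the other schemes do not. The sketch below mimics the attribute checks with plain namespaces instead of the real pydantic QuantizationArgs objects, so it only approximates the dispatch:

# Stand-ins for the pydantic QuantizationArgs objects that _is_w4a16 inspects.
from types import SimpleNamespace

w4a16_weights = SimpleNamespace(num_bits=4, symmetric=True, dynamic=False)
w8a8_weights = SimpleNamespace(num_bits=8, symmetric=True, dynamic=False)

def looks_like_w4a16(weight_quant, input_quant):
    # Same four conditions as CompressedTensorsConfig._is_w4a16 above.
    return (weight_quant.num_bits == 4 and input_quant is None
            and weight_quant.symmetric and not weight_quant.dynamic)

print(looks_like_w4a16(w4a16_weights, None))          # True
print(looks_like_w4a16(w8a8_weights, w8a8_weights))   # False
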
@@ -113,8 +139,9 @@ class CompressedTensorsConfig(QuantizationConfig):
             raise ValueError(
                 f"Could not find quantization details for {layer}.")
 
-        return self._get_schema(weight_quant=layer_quant_details["weight"],
-                                input_quant=layer_quant_details["input"])
+        return self._get_schema(
+            weight_quant=layer_quant_details["weights"],
+            input_quant=layer_quant_details["input_activations"])
 
 
 class CompressedTensorsLinearMethod(LinearMethodBase):
@@ -140,6 +167,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
             layer=layer,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
+            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)
@@ -1,6 +1,7 @@
 from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
+from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
@@ -0,0 +1,168 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
+    marlin_permute_scales)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW4A16"]
+
+
+class CompressedTensorsW4A16(CompressedTensorsScheme):
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None):
+        self.num_bits = num_bits
+        self.strategy = strategy
+        self.group_size = group_size
+
+        if self.strategy == "group" and self.group_size is None:
+            raise ValueError(
+                "group_size must be given when using strategy group")
+
+    def create_weights(self, layer: torch.nn.Module, input_size: int,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        pack_factor = 32 // self.num_bits
+        output_size_per_partition = sum(output_partition_sizes)
+
+        if self.group_size is not None:
+            group_size = self.group_size
+        else:
+            group_size = input_size
+
+        weight_scale_dim = None
+        scales_and_zp_size = input_size // group_size
+
+        if (input_size != input_size_per_partition
+                and self.group_size is not None):
+            weight_scale_dim = 1
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        weight = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(
+            weight, {
+                "input_dim": 1,
+                "output_dim": 0,
+                "packed_dim": 1,
+                "pack_factor": pack_factor
+            })
+        set_weight_attrs(weight, {"weight_loader": weight_loader})
+
+        layer.register_parameter("weight_packed", weight)
+
+        weight_scale = Parameter(
+            torch.empty(
+                output_size_per_partition,
+                scales_and_zp_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+
+        set_weight_attrs(weight_scale, {"weight_loader": weight_loader})
+        set_weight_attrs(weight_scale, {
+            "input_dim": weight_scale_dim,
+            "output_dim": 0
+        })
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # A 2D array defining the original shape of the weights
+        # before packing
+        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_shape", weight_shape)
+        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        layer.input_size = input_size
+        layer.marlin_state = GPTQMarlinState.REPACK
+        layer.is_k_full = True
+        layer.group_size = group_size
+
+        max_workspace_size = (
+            output_size_per_partition //
+            GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+
+        workspace = torch.zeros(max_workspace_size,
+                                dtype=torch.int,
+                                requires_grad=False)
+        layer.workspace = workspace
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        size_m = reshaped_x.shape[0]
+        part_size_n = layer.output_size_per_partition
+        part_size_k = layer.input_size_per_partition
+
+        out_shape = x.shape[:-1] + (part_size_n, )
+
+        if layer.marlin_state == GPTQMarlinState.REPACK:
+            layer.marlin_state = GPTQMarlinState.READY
+
+            # Newly generated tensors need to replace existing tensors that are
+            # already registered as parameters by vLLM (and won't be freed)
+            def replace_tensor(name, new_t):
+                # It is important to use resize_() here since it ensures
+                # the same buffer is reused
+                getattr(layer, name).resize_(new_t.shape)
+                getattr(layer, name).copy_(new_t)
+                del new_t
+
+            cur_device = layer.weight_packed.device
+
+            # Reset g_idx related tensors
+            layer.g_idx = Parameter(torch.empty(0,
+                                                dtype=torch.int,
+                                                device=cur_device),
+                                    requires_grad=False)
+            layer.g_idx_sort_indices = Parameter(torch.empty(
+                0, dtype=torch.int, device=cur_device),
+                                                 requires_grad=False)
+
+            # Repack weights
+            marlin_qweight = ops.gptq_marlin_repack(
+                layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
+                part_size_k, part_size_n, self.num_bits)
+
+            replace_tensor("weight_packed", marlin_qweight)
+
+            # Permute scales
+            scales_size_k = part_size_k
+            scales_size_n = part_size_n
+
+            marlin_scales = marlin_permute_scales(
+                layer.weight_scale.squeeze().t().contiguous(), scales_size_k,
+                scales_size_n, layer.group_size, self.num_bits)
+            replace_tensor("weight_scale", marlin_scales)
+
+        output = ops.gptq_marlin_gemm(reshaped_x, layer.weight_packed,
+                                      layer.weight_scale, layer.g_idx,
+                                      layer.g_idx_sort_indices,
+                                      layer.workspace, self.num_bits, size_m,
+                                      part_size_n, part_size_k,
+                                      layer.is_k_full)
+        return output.reshape(out_shape)
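For reference, a quick sanity sketch of the shape flow in apply_weights: activations are flattened to (size_m, part_size_k), the Marlin GEMM produces (size_m, part_size_n), and the result is reshaped back to the original leading dimensions. The sizes below are hypothetical and the kernel call is replaced with an empty tensor, so this only checks the bookkeeping, not the kernels.

# Hypothetical shapes; reproduces only the reshape bookkeeping of
# apply_weights, not the Marlin repack or GEMM kernels.
import torch

part_size_n = 2048                           # output_size_per_partition
x = torch.randn(2, 7, 2048)                  # (batch, seq, hidden) activations
reshaped_x = x.reshape(-1, x.shape[-1])      # (size_m, part_size_k)
out_shape = x.shape[:-1] + (part_size_n, )
gemm_out = torch.empty(reshaped_x.shape[0], part_size_n)  # stand-in result
print(gemm_out.reshape(out_shape).shape)     # torch.Size([2, 7, 2048])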