[Kernel] compressed-tensors marlin 24 support (#5435)

This commit is contained in:
Dipika Sikka 2024-06-17 12:32:48 -04:00 committed by GitHub
parent 9e74d9d003
commit 890d8d960b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 193 additions and 16 deletions

View File

@ -9,7 +9,8 @@ import torch
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod, CompressedTensorsW4A16,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
CompressedTensorsW8A8StaticTensor)
def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@ -51,8 +52,7 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
with vllm_runner(model_path, enforce_eager=True,
dtype=torch.float16) as llm:
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
@ -83,3 +83,20 @@ def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.weight_packed.pack_factor == 8
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
assert qkv_proj.weight_packed.dtype is torch.int32
sampling_params = SamplingParams()
output = llm.generate("Hello world!", sampling_params=sampling_params)
assert output

View File

@ -8,16 +8,20 @@ from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsW4A16,
CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
CompressedTensorsW8A8StaticTensor)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
CompressionFormat, QuantizationArgs, QuantizationStrategy,
find_first_name_or_class_match)
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
quant_format: str):
self.ignore = ignore
self.layer_quant_details = layer_quant_details
self.quant_format = quant_format
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
@ -46,6 +50,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
@ -69,7 +74,9 @@ class CompressedTensorsConfig(QuantizationConfig):
except Exception:
layer_quant_details[target]["input_activations"] = None
return cls(layer_quant_details=layer_quant_details, ignore=ignore)
return cls(layer_quant_details=layer_quant_details,
ignore=ignore,
quant_format=quant_format)
@classmethod
def get_config_filenames(cls) -> List[str]:
@ -110,17 +117,26 @@ class CompressedTensorsConfig(QuantizationConfig):
input_quant: BaseModel) -> "CompressedTensorsScheme":
if self._is_w4a16(weight_quant, input_quant):
return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
if self.quant_format == CompressionFormat.marlin_24.value:
return CompressedTensorsW4A16Sparse24(
strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits,
group_size=weight_quant.group_size)
if self.quant_format == CompressionFormat.pack_quantized.value:
return CompressedTensorsW4A16(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)
if self.quant_format == CompressionFormat.int_quantized.value:
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor()
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8DynamicToken()
raise NotImplementedError("Scheme not supported.")
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
@ -165,9 +181,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme = self.quantization_config.get_scheme(layer=layer)
scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
input_size=input_size,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)

View File

@ -2,6 +2,8 @@ from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401
from .compressed_tensors_w4a16_24 import ( # noqa: F401
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501
CompressedTensorsW8A8DynamicToken)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501

View File

@ -0,0 +1,134 @@
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW4A16Sparse24"]
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
self.strategy = strategy
self.group_size = group_size
self.num_bits = num_bits
self.tile_size = 16
if self.strategy == "group" and self.group_size is None:
raise ValueError(
"group_size must be given when using strategy group")
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
pack_factor = 32 // self.num_bits
output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": pack_factor,
"marlin_tile_size": self.tile_size,
"weight_loader": weight_loader
},
)
layer.register_parameter("weight_packed", qweight)
input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)
scales = Parameter(
torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"output_dim": 1,
"input_dim": None if input_groups == 1 else 0,
"weight_loader": weight_loader
},
)
layer.register_parameter("scale_packed", scales)
weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
requires_grad=False)
layer.register_parameter("weight_shape", weight_shape)
set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
meta = Parameter(
torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
requires_grad=False,
)
set_weight_attrs(
meta,
{
"input_dim": 0,
"packed_dim": 1,
"pack_factor": 1,
"output_dim": 1,
"marlin_tile_size": 2,
"weight_loader": weight_loader
},
)
layer.register_parameter("meta", meta)
max_workspace_size = (
output_size_per_partition //
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
requires_grad=False)
layer.workspace = workspace
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
qweight = layer.weight_packed
meta = layer.meta
scales = layer.scale_packed
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, self.num_bits, size_m,
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
return output

View File

@ -6,6 +6,14 @@ from pydantic import BaseModel, Field
from torch.nn import Module
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
int_quantized = "int-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"
class QuantizationType(str, Enum):
"""
Enum storing quantization type options