[Kernel] compressed-tensors marlin 24 support (#5435)

commit 890d8d960b (parent 9e74d9d003)
tests/quantization/test_compressed_tensors.py
@@ -9,7 +9,8 @@ import torch
 from vllm import SamplingParams
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
     CompressedTensorsLinearMethod, CompressedTensorsW4A16,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)


 def test_compressed_tensors_w8a8_static_setup(vllm_runner):
@@ -51,8 +52,7 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):

 def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2"
-    with vllm_runner(model_path, enforce_eager=True,
-                     dtype=torch.float16) as llm:
+    with vllm_runner(model_path, dtype=torch.float16) as llm:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         layer = model.model.layers[0]

@@ -83,3 +83,20 @@ def test_compressed_tensors_w4a16(vllm_runner, w4a16_args):
         assert qkv_proj.weight_packed.dtype is torch.int32
         assert qkv_proj.weight_scale.dtype is torch.float16
         assert qkv_proj.weight_packed.pack_factor == 8
+
+
+def test_compressed_tensors_w4a16_marlin24(vllm_runner):
+    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
+    with vllm_runner(model_path) as llm:
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+
+        qkv_proj = layer.self_attn.qkv_proj
+
+        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
+        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
+        assert qkv_proj.weight_packed.dtype is torch.int32
+
+        sampling_params = SamplingParams()
+        output = llm.generate("Hello world!", sampling_params=sampling_params)
+        assert output
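Note: outside the test fixture, the new path can be exercised through the plain vLLM API. A minimal sketch (illustrative only; it assumes the same checkpoint the test above uses and adds an assumed max_tokens value):

from vllm import LLM, SamplingParams

# Illustrative standalone run of the 2:4 sparse marlin24 checkpoint used in
# the test above; loading it routes linear layers through
# CompressedTensorsW4A16Sparse24.
llm = LLM(model="nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t")
outputs = llm.generate(["Hello world!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)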
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -8,16 +8,20 @@ from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme, CompressedTensorsW4A16,
-    CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
+    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8DynamicToken,
+    CompressedTensorsW8A8StaticTensor)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
-    QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
+    CompressionFormat, QuantizationArgs, QuantizationStrategy,
+    find_first_name_or_class_match)


 class CompressedTensorsConfig(QuantizationConfig):

-    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
+    def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
+                 quant_format: str):
         self.ignore = ignore
         self.layer_quant_details = layer_quant_details
+        self.quant_format = quant_format

     def get_linear_method(self) -> "CompressedTensorsLinearMethod":
         return CompressedTensorsLinearMethod(self)
@@ -46,6 +50,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
         layer_quant_details: Dict[str, Any] = dict()
         ignore: List[str] = config.get("ignore", None)
+        quant_format: str = config.get("format", None)

         # The quant_config has multiple config_groups, each containing
         # an input_activations key with details about how the activations are
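Note: the "format" key read here comes from the checkpoint's quantization_config. A sketch of the shape of such a config for the new path (field names beyond "config_groups", "ignore", and "format" are illustrative assumptions, not taken from this diff):

# Illustrative quantization_config as serialized in a checkpoint's
# config.json; from_config above reads only "ignore" and "format" directly,
# and walks "config_groups" for the per-layer details.
quantization_config = {
    "config_groups": {
        "group_0": {
            "weights": {"num_bits": 4, "strategy": "group", "group_size": 128},
            "input_activations": None,
            "targets": ["Linear"],  # assumed key for matched layer types
        },
    },
    "format": "marlin-24",  # routes w4a16 layers to CompressedTensorsW4A16Sparse24
    "ignore": ["lm_head"],
}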
@@ -69,7 +74,9 @@ class CompressedTensorsConfig(QuantizationConfig):
         except Exception:
             layer_quant_details[target]["input_activations"] = None

-        return cls(layer_quant_details=layer_quant_details, ignore=ignore)
+        return cls(layer_quant_details=layer_quant_details,
+                   ignore=ignore,
+                   quant_format=quant_format)

     @classmethod
     def get_config_filenames(cls) -> List[str]:
@@ -110,17 +117,26 @@
                     input_quant: BaseModel) -> "CompressedTensorsScheme":

         if self._is_w4a16(weight_quant, input_quant):
-            return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
-                                          strategy=weight_quant.strategy,
-                                          group_size=weight_quant.group_size)
+            if self.quant_format == CompressionFormat.marlin_24.value:
+                return CompressedTensorsW4A16Sparse24(
+                    strategy=weight_quant.strategy,
+                    num_bits=weight_quant.num_bits,
+                    group_size=weight_quant.group_size)
+            if self.quant_format == CompressionFormat.pack_quantized.value:
+                return CompressedTensorsW4A16(
+                    num_bits=weight_quant.num_bits,
+                    strategy=weight_quant.strategy,
+                    group_size=weight_quant.group_size)

-        if self._is_static_tensor_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8StaticTensor()
+        if self.quant_format == CompressionFormat.int_quantized.value:
+            if self._is_static_tensor_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8StaticTensor()

-        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
-            return CompressedTensorsW8A8DynamicToken()
+            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
+                return CompressedTensorsW8A8DynamicToken()

-        raise NotImplementedError("Scheme not supported.")
+        raise NotImplementedError(
+            "No compressed-tensors compatible scheme was found.")

     def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":

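Note: summarizing the dispatch above (a sketch of the routing, not part of the diff), the checkpoint's format string now gates scheme selection:

# format == "marlin-24"      + w4a16 weights       -> CompressedTensorsW4A16Sparse24
# format == "pack-quantized" + w4a16 weights       -> CompressedTensorsW4A16
# format == "int-quantized"  + static-tensor w8a8  -> CompressedTensorsW8A8StaticTensor
# format == "int-quantized"  + dynamic-token w8a8  -> CompressedTensorsW8A8DynamicToken
# anything else                                    -> NotImplementedError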
@@ -165,9 +181,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
         scheme = self.quantization_config.get_scheme(layer=layer)
         scheme.create_weights(
             layer=layer,
+            input_size=input_size,
             input_size_per_partition=input_size_per_partition,
             output_partition_sizes=output_partition_sizes,
-            input_size=input_size,
             output_size=output_size,
             params_dtype=params_dtype,
             weight_loader=weight_loader)
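Note: with input_size forwarded before the per-partition arguments, a scheme can distinguish the layer's full K dimension from the shard a rank owns under tensor parallelism. A sketch of the assumed relationship (values illustrative, not from this diff):

# Assumed TP semantics: input_size is the unsharded K of the linear layer;
# input_size_per_partition is the slice each rank allocates weights for.
tp_size = 2
input_size = 4096                                  # full K of the layer
input_size_per_partition = input_size // tp_size   # 2048 owned by each rank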
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
@@ -2,6 +2,8 @@ from .compressed_tensors_scheme import CompressedTensorsScheme  # noqa: F401
 from .compressed_tensors_unquantized import (  # noqa: F401
     CompressedTensorsUnquantized)
 from .compressed_tensors_w4a16 import CompressedTensorsW4A16  # noqa: F401
+from .compressed_tensors_w4a16_24 import (  # noqa: F401
+    CompressedTensorsW4A16Sparse24)
 from .compressed_tensors_w8a8_dynamictoken import (  # noqa: F401, E501
     CompressedTensorsW8A8DynamicToken)
 from .compressed_tensors_w8a8_statictensor import (  # noqa: F401, E501
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
@@ -0,0 +1,134 @@
+from typing import Callable, List, Optional
+
+import torch
+from torch.nn import Parameter
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+    GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
+from vllm.model_executor.utils import set_weight_attrs
+
+__all__ = ["CompressedTensorsW4A16Sparse24"]
+
+
+class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
+
+    def __init__(self,
+                 strategy: str,
+                 num_bits: int,
+                 group_size: Optional[int] = None):
+        self.strategy = strategy
+        self.group_size = group_size
+        self.num_bits = num_bits
+        self.tile_size = 16
+
+        if self.strategy == "group" and self.group_size is None:
+            raise ValueError(
+                "group_size must be given when using strategy group")
+
+    def create_weights(self, layer: torch.nn.Module, input_size: int,
+                       output_partition_sizes: List[int],
+                       input_size_per_partition: int,
+                       params_dtype: torch.dtype, weight_loader: Callable,
+                       **kwargs):
+
+        pack_factor = 32 // self.num_bits
+        output_size_per_partition = sum(output_partition_sizes)
+
+        qweight = Parameter(
+            torch.empty(
+                input_size_per_partition // self.tile_size // 2,
+                output_size_per_partition * self.tile_size // pack_factor,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            qweight,
+            {
+                "input_dim": 0,
+                "output_dim": 1,
+                "packed_dim": 1,
+                "pack_factor": pack_factor,
+                "marlin_tile_size": self.tile_size,
+                "weight_loader": weight_loader
+            },
+        )
+
+        layer.register_parameter("weight_packed", qweight)
+
+        input_groups = (1 if self.group_size is None else
+                        input_size_per_partition // self.group_size)
+
+        scales = Parameter(
+            torch.empty(
+                input_groups,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            scales,
+            {
+                "output_dim": 1,
+                "input_dim": None if input_groups == 1 else 0,
+                "weight_loader": weight_loader
+            },
+        )
+        layer.register_parameter("scale_packed", scales)
+
+        weight_shape = Parameter(torch.empty(2, dtype=torch.int64),
+                                 requires_grad=False)
+
+        layer.register_parameter("weight_shape", weight_shape)
+        set_weight_attrs(weight_shape, {"weight_loader": weight_loader})
+
+        meta = Parameter(
+            torch.empty(
+                input_size_per_partition // 8 // 2 // 2,
+                output_size_per_partition * 2,
+                dtype=torch.int16,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(
+            meta,
+            {
+                "input_dim": 0,
+                "packed_dim": 1,
+                "pack_factor": 1,
+                "output_dim": 1,
+                "marlin_tile_size": 2,
+                "weight_loader": weight_loader
+            },
+        )
+        layer.register_parameter("meta", meta)
+
+        max_workspace_size = (
+            output_size_per_partition //
+            GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
+        workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
+                              requires_grad=False)
+        layer.workspace = workspace
+
+    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
+        qweight = layer.weight_packed
+        meta = layer.meta
+        scales = layer.scale_packed
+        workspace = layer.workspace
+
+        x_2d = x.view(-1, x.shape[-1])
+
+        size_m = x_2d.shape[0]
+        size_k = x_2d.shape[1]
+        size_n = scales.shape[1]
+
+        output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
+                                            workspace, self.num_bits, size_m,
+                                            size_n, size_k)
+
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
+        return output
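Note: to make the packing arithmetic in create_weights concrete, a small worked example (the layer sizes are illustrative; the formulas are the ones used in the new file above):

# Buffer shapes for an illustrative K=4096 x N=4096 layer.
num_bits = 4
tile_size = 16
pack_factor = 32 // num_bits  # 8 four-bit values per int32
K, N = 4096, 4096

# 2:4 sparsity keeps half the weights, hence the extra // 2 on the K axis.
qweight_shape = (K // tile_size // 2, N * tile_size // pack_factor)
# 2:4 metadata stored as int16 words, packed along N with marlin_tile_size = 2.
meta_shape = (K // 8 // 2 // 2, N * 2)

print(qweight_shape)  # (128, 8192)
print(meta_shape)     # (128, 8192)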
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
@@ -6,6 +6,14 @@ from pydantic import BaseModel, Field
 from torch.nn import Module


+class CompressionFormat(Enum):
+    dense = "dense"
+    sparse_bitmask = "sparse-bitmask"
+    int_quantized = "int-quantized"
+    pack_quantized = "pack-quantized"
+    marlin_24 = "marlin-24"
+
+
 class QuantizationType(str, Enum):
     """
     Enum storing quantization type options
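Note: a self-contained sketch of how value-based lookup on the new enum behaves (the enum body is duplicated here so the snippet runs standalone):

from enum import Enum

class CompressionFormat(Enum):
    dense = "dense"
    sparse_bitmask = "sparse-bitmask"
    int_quantized = "int-quantized"
    pack_quantized = "pack-quantized"
    marlin_24 = "marlin-24"

# The scheme dispatch compares the config's raw string against .value;
# Enum's value-based constructor gives the same mapping and raises a
# ValueError on unknown format strings.
assert CompressionFormat("marlin-24") is CompressionFormat.marlin_24
assert CompressionFormat.marlin_24.value == "marlin-24"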