[ Kernel ] Fp8 Channelwise Weight Support (#6487)

Robert Shaw 2024-07-17 23:18:13 -04:00 committed by GitHub
parent b5af8c223c
commit 18fecc3559
4 changed files with 76 additions and 35 deletions

vllm/config.py

@@ -238,7 +238,8 @@ class ModelConfig:
                 f"{self.quantization} quantization is currently not "
                 f"supported in ROCm.")
         if (self.quantization
-                not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
+                not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
+                        "compressed_tensors")):
             logger.warning(
                 "%s quantization is not fully "
                 "optimized yet. The speed can be slower than "

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -13,7 +13,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_first_name_or_class_match)
+    QuantizationType, find_first_name_or_class_match,
+    is_activation_quantization_format)
 from vllm.platforms import current_platform
@@ -132,10 +133,11 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Confirm weight scheme is supported.
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_weight = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR)
+        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        ])
         if not (is_symmetric_weight and is_static_weight
-                and is_per_tensor_weight):
+                and is_per_tensor_or_channel_weight):
             return False

         # Dynamic quantization is always supported if weights supported.
@@ -167,6 +169,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":

+        # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
@@ -182,11 +185,12 @@ class CompressedTensorsConfig(QuantizationConfig):
                 strategy=weight_quant.strategy,
                 group_size=weight_quant.group_size)

-        if (self.quant_format == CompressionFormat.int_quantized.value or
-                self.quant_format == CompressionFormat.float_quantized.value):
+        # Detect If Activation Quantization.
+        if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8Fp8(
-                    input_dynamic=input_quant.dynamic)
+                    strategy=weight_quant.strategy,
+                    is_static_input_scheme=(not input_quant.dynamic))

             if self._is_static_tensor_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8Int8(
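
For orientation, here is a minimal standalone sketch of the selection rule this hunk introduces (QuantArgs and pick_scheme are illustrative names, not vLLM APIs): symmetric, static fp8 weights with fp8 activations now select the fp8 W8A8 scheme for both per-tensor and per-channel weight scales, and the weight strategy plus the input staticness are passed through to the scheme.

# Standalone sketch of the scheme-selection rule above; QuantArgs and
# pick_scheme are illustrative names, not vLLM APIs.
from dataclasses import dataclass

ACTIVATION_QUANT_FORMATS = {"naive-quantized", "int-quantized", "float-quantized"}


@dataclass
class QuantArgs:
    type: str        # "float" for fp8, "int" for int8
    strategy: str    # "tensor" or "channel" for weights
    symmetric: bool = True
    dynamic: bool = False


def pick_scheme(quant_format: str, weight: QuantArgs, act: QuantArgs) -> str:
    if quant_format not in ACTIVATION_QUANT_FORMATS:
        return "not an activation-quantization format"
    is_fp8_w8a8 = (weight.type == "float" and act.type == "float"
                   and weight.symmetric and not weight.dynamic
                   and weight.strategy in ("tensor", "channel"))
    if is_fp8_w8a8:
        return (f"CompressedTensorsW8A8Fp8(strategy={weight.strategy!r}, "
                f"is_static_input_scheme={not act.dynamic})")
    return "fall through to the int8 schemes"


print(pick_scheme("float-quantized",
                  QuantArgs(type="float", strategy="channel"),
                  QuantArgs(type="float", strategy="tensor", dynamic=True)))
# -> CompressedTensorsW8A8Fp8(strategy='channel', is_static_input_scheme=False)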

vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

@@ -1,11 +1,15 @@
 from typing import Callable, List, Optional

 import torch
+from torch.nn import Parameter

 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
+    apply_fp8_linear, create_per_channel_scale_param,
+    create_per_tensor_scale_param, cutlass_fp8_supported,
     requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
@@ -14,39 +18,56 @@ __all__ = ["CompressedTensorsW8A8Fp8"]

 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):

-    def __init__(self, input_dynamic: bool):
-        self.input_dynamic = input_dynamic
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()

-    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
-    # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
-    # scales being passed to the kernel), we requantize with a single scale.
+        # On Lovelace, fail for now if channelwise.
+        # TODO: (@tms) fallback
+        if (not self.cutlass_fp8_supported
+                and self.strategy == QuantizationStrategy.CHANNEL):
+            raise ValueError(
+                "Channelwise fp8 quantization requires vLLM's custom "
+                "cutlass kernels, which are not supported on your device."
+                "Consider quantizing with per tensor scales or upgrading "
+                "to Hopper.")

     def process_weights_after_loading(self, layer) -> None:
-        # Dequant -> Quant with max scale.
-        max_w_scale, weight = requantize_with_max_scale(
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            logical_widths=layer.logical_widths,
-        )
-
-        # Update layer with new values.
-        layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(max_w_scale,
-                                                requires_grad=False)
-        if self.input_dynamic:
-            layer.input_scale = None
-        else:
-            layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
-                                                   requires_grad=False)
+        # If per tensor, when we have a fused module (e.g. QKV) with per
+        # tensor scales (thus N scales being passed to the kernel),
+        # requantize so we can always run per tensor
+        if self.strategy == QuantizationStrategy.TENSOR:
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+        # If channelwise, scales are already lined up, so just transpose.
+        elif self.strategy == QuantizationStrategy.CHANNEL:
+            assert self.cutlass_fp8_supported
+            weight = layer.weight
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
+        else:
+            raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            layer.input_scale = Parameter(layer.input_scale.max(),
+                                          requires_grad=False)
+        else:
+            layer.input_scale = None

     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
         del params_dtype
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
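
The comment in the per-tensor branch is the key constraint: a fused module (e.g. QKV) arrives with one fp8 scale per shard, but the per-tensor kernel wants a single scale, so the weights are dequantized and requantized under the max scale. A simplified standalone illustration of that step follows (fp8 storage is faked with clamping here; this is not the vLLM kernel):

# Simplified, standalone illustration of the per-tensor requantization path.
# Real fp8 storage is faked with clamping so the numerics stay readable.
import torch

FP8_MAX = 448.0  # finite max of float8_e4m3fn


def fake_fp8_quant(w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return torch.clamp(w / scale, -FP8_MAX, FP8_MAX)


def requantize_with_max_scale_sketch(weight_q, weight_scale, logical_widths):
    # Dequantize each fused shard (e.g. Q, K, V) with its own per-tensor
    # scale, then requantize everything with the single max scale so the
    # per-tensor kernel sees one scale for the whole fused weight.
    max_scale = weight_scale.max()
    out = torch.empty_like(weight_q)
    start = 0
    for shard_scale, width in zip(weight_scale, logical_widths):
        end = start + width
        dequant = weight_q[start:end] * shard_scale          # back to full precision
        out[start:end] = fake_fp8_quant(dequant, max_scale)  # re-quantize
        start = end
    return max_scale, out


q, k = torch.randn(4, 8), torch.randn(4, 8)
scales = torch.stack([q.abs().max() / FP8_MAX, k.abs().max() / FP8_MAX])
fused = torch.cat([fake_fp8_quant(q, scales[0]), fake_fp8_quant(k, scales[1])])
max_scale, requant = requantize_with_max_scale_sketch(fused, scales, [4, 4])
# One scale now reconstructs the whole fused weight.
print(torch.allclose(requant * max_scale, torch.cat([q, k]), atol=1e-3))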
@@ -63,12 +84,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         })

         # WEIGHT SCALE
-        weight_scale = create_per_tensor_scale_param(
-            output_partition_sizes, weight_loader=weight_loader)
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = create_per_channel_scale_param(
+                output_partition_sizes, weight_loader=weight_loader)
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = create_per_tensor_scale_param(
+                output_partition_sizes, weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)

         # INPUT SCALE
-        if not self.input_dynamic:
+        if self.is_static_input_scheme:
             input_scale = create_per_tensor_scale_param(
                 output_partition_sizes, weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)

vllm/model_executor/layers/quantization/compressed_tensors/utils.py

@@ -9,6 +9,7 @@ from torch.nn import Module
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    naive_quantized = "naive-quantized"
     float_quantized = "float-quantized"
     int_quantized = "int-quantized"
     pack_quantized = "pack-quantized"
@@ -76,6 +77,15 @@ class QuantizationArgs(BaseModel):
     )


+def is_activation_quantization_format(format: str) -> bool:
+    _ACTIVATION_QUANTIZATION_FORMATS = [
+        CompressionFormat.naive_quantized.value,
+        CompressionFormat.int_quantized.value,
+        CompressionFormat.float_quantized.value
+    ]
+    return format in _ACTIVATION_QUANTIZATION_FORMATS
+
+
 def find_first_name_or_class_match(
     name: str,
     module: Module,
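
Usage of the new helper is direct (assumes a vLLM install that includes this commit):

# Example inputs/outputs for the helper added above.
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
    is_activation_quantization_format)

assert is_activation_quantization_format("float-quantized")      # fp8 W8A8
assert is_activation_quantization_format("int-quantized")        # int8 W8A8
assert is_activation_quantization_format("naive-quantized")
assert not is_activation_quantization_format("pack-quantized")   # weight-only
assert not is_activation_quantization_format("dense")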