[ Kernel ] Fp8 Channelwise Weight Support (#6487)
This commit is contained in:
parent
b5af8c223c
commit
18fecc3559
@ -238,7 +238,8 @@ class ModelConfig:
|
||||
f"{self.quantization} quantization is currently not "
|
||||
f"supported in ROCm.")
|
||||
if (self.quantization
|
||||
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
|
||||
not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
|
||||
"compressed_tensors")):
|
||||
logger.warning(
|
||||
"%s quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
|
@ -13,7 +13,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
CompressionFormat, QuantizationArgs, QuantizationStrategy,
|
||||
QuantizationType, find_first_name_or_class_match)
|
||||
QuantizationType, find_first_name_or_class_match,
|
||||
is_activation_quantization_format)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@ -132,10 +133,11 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# Confirm weight scheme is supported.
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_per_tensor_weight = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR)
|
||||
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
|
||||
])
|
||||
if not (is_symmetric_weight and is_static_weight
|
||||
and is_per_tensor_weight):
|
||||
and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
@ -167,6 +169,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
def _get_schema(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> "CompressedTensorsScheme":
|
||||
|
||||
# Detect If Mixed Precision
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
self._check_gptq_and_marlin_can_run()
|
||||
if (self.quant_format == CompressionFormat.marlin_24.value
|
||||
@ -182,11 +185,12 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
strategy=weight_quant.strategy,
|
||||
group_size=weight_quant.group_size)
|
||||
|
||||
if (self.quant_format == CompressionFormat.int_quantized.value or
|
||||
self.quant_format == CompressionFormat.float_quantized.value):
|
||||
# Detect If Activation Quantization.
|
||||
if is_activation_quantization_format(self.quant_format):
|
||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Fp8(
|
||||
input_dynamic=input_quant.dynamic)
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=(not input_quant.dynamic))
|
||||
|
||||
if self._is_static_tensor_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8(
|
||||
|
@ -1,11 +1,15 @@
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
|
||||
apply_fp8_linear, create_per_channel_scale_param,
|
||||
create_per_tensor_scale_param, cutlass_fp8_supported,
|
||||
requantize_with_max_scale)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
@ -14,39 +18,56 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, input_dynamic: bool):
|
||||
self.input_dynamic = input_dynamic
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), we requantize with a single scale.
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# Dequant -> Quant with max scale.
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
# On Lovelace, fail for now if channelwise.
|
||||
# TODO: (@tms) fallback
|
||||
if (not self.cutlass_fp8_supported
|
||||
and self.strategy == QuantizationStrategy.CHANNEL):
|
||||
raise ValueError(
|
||||
"Channelwise fp8 quantization requires vLLM's custom "
|
||||
"cutlass kernels, which are not supported on your device."
|
||||
"Consider quantizing with per tensor scales or upgrading "
|
||||
"to Hopper.")
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
assert self.cutlass_fp8_supported
|
||||
weight = layer.weight
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
|
||||
# Update layer with new values.
|
||||
layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = torch.nn.Parameter(max_w_scale,
|
||||
requires_grad=False)
|
||||
if self.input_dynamic:
|
||||
layer.input_scale = None
|
||||
else:
|
||||
layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
raise ValueError(f"Unknown quantization strategy {self.strategy}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
del params_dtype
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
@ -63,12 +84,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
})
|
||||
|
||||
# WEIGHT SCALE
|
||||
weight_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = create_per_channel_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if not self.input_dynamic:
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
@ -9,6 +9,7 @@ from torch.nn import Module
|
||||
class CompressionFormat(Enum):
|
||||
dense = "dense"
|
||||
sparse_bitmask = "sparse-bitmask"
|
||||
naive_quantized = "naive-quantized"
|
||||
float_quantized = "float-quantized"
|
||||
int_quantized = "int-quantized"
|
||||
pack_quantized = "pack-quantized"
|
||||
@ -76,6 +77,15 @@ class QuantizationArgs(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
def is_activation_quantization_format(format: str) -> bool:
|
||||
_ACTIVATION_QUANTIZATION_FORMATS = [
|
||||
CompressionFormat.naive_quantized.value,
|
||||
CompressionFormat.int_quantized.value,
|
||||
CompressionFormat.float_quantized.value
|
||||
]
|
||||
return format in _ACTIVATION_QUANTIZATION_FORMATS
|
||||
|
||||
|
||||
def find_first_name_or_class_match(
|
||||
name: str,
|
||||
module: Module,
|
||||
|
Loading…
x
Reference in New Issue
Block a user