[ Kernel ] Fp8 Channelwise Weight Support (#6487)
parent b5af8c223c
commit 18fecc3559
@@ -238,7 +238,8 @@ class ModelConfig:
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
             if (self.quantization
-                    not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
+                    not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin",
+                            "compressed_tensors")):
                 logger.warning(
                     "%s quantization is not fully "
                     "optimized yet. The speed can be slower than "
@@ -13,7 +13,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsWNA16)
 from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
     CompressionFormat, QuantizationArgs, QuantizationStrategy,
-    QuantizationType, find_first_name_or_class_match)
+    QuantizationType, find_first_name_or_class_match,
+    is_activation_quantization_format)
 from vllm.platforms import current_platform
 
 
@@ -132,10 +133,11 @@ class CompressedTensorsConfig(QuantizationConfig):
         # Confirm weight scheme is supported.
         is_symmetric_weight = weight_quant.symmetric
         is_static_weight = not weight_quant.dynamic
-        is_per_tensor_weight = (
-            weight_quant.strategy == QuantizationStrategy.TENSOR)
+        is_per_tensor_or_channel_weight = (weight_quant.strategy in [
+            QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
+        ])
         if not (is_symmetric_weight and is_static_weight
-                and is_per_tensor_weight):
+                and is_per_tensor_or_channel_weight):
             return False
 
         # Dynamic quantization is always supported if weights supported.
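
For reference, a standalone sketch of the weight-support check after this hunk (the `weight_quant` stand-ins and strategy strings below are assumptions for illustration, not vLLM objects): symmetric, static, channelwise weights now pass, while dynamic weight schemes are still rejected.

from types import SimpleNamespace

# Hypothetical stand-ins for the pydantic weight_quant models used above.
per_channel = SimpleNamespace(symmetric=True, dynamic=False, strategy="channel")
dynamic_weights = SimpleNamespace(symmetric=True, dynamic=True, strategy="channel")

def weights_supported(weight_quant) -> bool:
    # Mirrors the check in the hunk above: symmetric, static, and
    # per-tensor or per-channel strategy.
    return (weight_quant.symmetric and not weight_quant.dynamic
            and weight_quant.strategy in ("tensor", "channel"))

assert weights_supported(per_channel)          # newly accepted by this change
assert not weights_supported(dynamic_weights)  # dynamic weights still rejected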
@@ -167,6 +169,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def _get_schema(self, weight_quant: BaseModel,
                     input_quant: BaseModel) -> "CompressedTensorsScheme":
 
+        # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
             self._check_gptq_and_marlin_can_run()
             if (self.quant_format == CompressionFormat.marlin_24.value
@@ -182,11 +185,12 @@ class CompressedTensorsConfig(QuantizationConfig):
                 strategy=weight_quant.strategy,
                 group_size=weight_quant.group_size)
 
-        if (self.quant_format == CompressionFormat.int_quantized.value or
-                self.quant_format == CompressionFormat.float_quantized.value):
+        # Detect If Activation Quantization.
+        if is_activation_quantization_format(self.quant_format):
             if self._is_fp8_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8Fp8(
-                    input_dynamic=input_quant.dynamic)
+                    strategy=weight_quant.strategy,
+                    is_static_input_scheme=(not input_quant.dynamic))
 
             if self._is_static_tensor_w8a8(weight_quant, input_quant):
                 return CompressedTensorsW8A8Int8(
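
As a usage note (argument values below are hypothetical), the dispatch above now hands the scheme both the weight strategy and whether a static input scale will be loaded, derived from `input_quant.dynamic`:

from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsW8A8Fp8)

# Channelwise weights with dynamic per-token activations:
#   input_quant.dynamic is True  ->  is_static_input_scheme=False
scheme = CompressedTensorsW8A8Fp8(strategy="channel",
                                  is_static_input_scheme=False)

# Per-tensor weights with static activation scales:
#   input_quant.dynamic is False ->  is_static_input_scheme=True
scheme = CompressedTensorsW8A8Fp8(strategy="tensor",
                                  is_static_input_scheme=True)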
@@ -1,11 +1,15 @@
 from typing import Callable, List, Optional
 
 import torch
+from torch.nn import Parameter
 
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme)
+from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
+    QuantizationStrategy)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
-    apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
+    apply_fp8_linear, create_per_channel_scale_param,
+    create_per_tensor_scale_param, cutlass_fp8_supported,
     requantize_with_max_scale)
 from vllm.model_executor.utils import set_weight_attrs
 
@@ -14,39 +18,56 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
 
 class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
 
-    def __init__(self, input_dynamic: bool):
-        self.input_dynamic = input_dynamic
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
-    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
-    # So if we have a fused module (QKV, MLP) with per tensor scales (thus N
-    # scales being passed to the kernel), we requantize with a single scale.
+        # On Lovelace, fail for now if channelwise.
+        # TODO: (@tms) fallback
+        if (not self.cutlass_fp8_supported
+                and self.strategy == QuantizationStrategy.CHANNEL):
+            raise ValueError(
+                "Channelwise fp8 quantization requires vLLM's custom "
+                "cutlass kernels, which are not supported on your device."
+                "Consider quantizing with per tensor scales or upgrading "
+                "to Hopper.")
+
     def process_weights_after_loading(self, layer) -> None:
-        # Dequant -> Quant with max scale.
-        max_w_scale, weight = requantize_with_max_scale(
-            weight=layer.weight,
-            weight_scale=layer.weight_scale,
-            logical_widths=layer.logical_widths,
-        )
+        # If per tensor, when we have a fused module (e.g. QKV) with per
+        # tensor scales (thus N scales being passed to the kernel),
+        # requantize so we can always run per tensor
+        if self.strategy == QuantizationStrategy.TENSOR:
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
 
-        # Update layer with new values.
-        layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
-        layer.weight_scale = torch.nn.Parameter(max_w_scale,
-                                                requires_grad=False)
-        if self.input_dynamic:
-            layer.input_scale = None
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+        # If channelwise, scales are already lined up, so just transpose.
+        elif self.strategy == QuantizationStrategy.CHANNEL:
+            assert self.cutlass_fp8_supported
+            weight = layer.weight
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+
         else:
-            layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
-                                                   requires_grad=False)
+            raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            layer.input_scale = Parameter(layer.input_scale.max(),
+                                          requires_grad=False)
+        else:
+            layer.input_scale = None
 
     def create_weights(self, layer: torch.nn.Module,
                        output_partition_sizes: List[int],
                        input_size_per_partition: int,
                        params_dtype: torch.dtype, weight_loader: Callable,
                        **kwargs):
 
-        del params_dtype
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
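
To illustrate the scale handling above, here is a standalone sketch (not vLLM code; the shapes and the use of `torch.float8_e4m3fn` are assumptions) of why channelwise scales only need a transpose while the per-tensor scales of a fused module are collapsed to a single max scale before requantizing:

import torch

# Hypothetical fused QKV projection: three logical output partitions.
logical_widths = [64, 32, 32]
weight = torch.randn(sum(logical_widths), 128)
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

# Channelwise: one scale per output row. Fusing Q/K/V just concatenates
# rows and their scales, so they stay "lined up" with the weight and only
# the weight needs transposing for the kernel.
channel_scale = weight.abs().amax(dim=1, keepdim=True) / FP8_MAX
w_fp8_channel = (weight / channel_scale).clamp(-FP8_MAX, FP8_MAX).to(
    torch.float8_e4m3fn)

# Per-tensor: each partition arrives with its own single scale, so the
# fused module holds N scales. Collapsing them to the max and requantizing
# (what requantize_with_max_scale does) lets the layer run one per-tensor
# GEMM with a single scale.
partition_scales = torch.stack(
    [w.abs().max() / FP8_MAX for w in weight.split(logical_widths, dim=0)])
max_scale = partition_scales.max()
w_fp8_tensor = (weight / max_scale).clamp(-FP8_MAX, FP8_MAX).to(
    torch.float8_e4m3fn)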
@@ -63,12 +84,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         })
 
         # WEIGHT SCALE
-        weight_scale = create_per_tensor_scale_param(
-            output_partition_sizes, weight_loader=weight_loader)
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = create_per_channel_scale_param(
+                output_partition_sizes, weight_loader=weight_loader)
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = create_per_tensor_scale_param(
+                output_partition_sizes, weight_loader=weight_loader)
         layer.register_parameter("weight_scale", weight_scale)
 
         # INPUT SCALE
-        if not self.input_dynamic:
+        if self.is_static_input_scheme:
             input_scale = create_per_tensor_scale_param(
                 output_partition_sizes, weight_loader=weight_loader)
             layer.register_parameter("input_scale", input_scale)
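
The two scale parameters differ mainly in shape; a rough sketch (the exact shapes are assumptions about the helpers, not taken from this diff):

import torch

output_partition_sizes = [64, 32, 32]  # hypothetical fused QKV widths

# Per-channel: one fp32 scale per output channel of the fused weight,
# so it never needs requantization after loading.
per_channel_scale = torch.empty((sum(output_partition_sizes), 1),
                                dtype=torch.float32)

# Per-tensor: one fp32 scale per logical partition (Q, K, V), collapsed to
# a single scale in process_weights_after_loading.
per_tensor_scale = torch.empty(len(output_partition_sizes),
                               dtype=torch.float32)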
@@ -9,6 +9,7 @@ from torch.nn import Module
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    naive_quantized = "naive-quantized"
     float_quantized = "float-quantized"
     int_quantized = "int-quantized"
     pack_quantized = "pack-quantized"
@@ -76,6 +77,15 @@ class QuantizationArgs(BaseModel):
     )
 
 
+def is_activation_quantization_format(format: str) -> bool:
+    _ACTIVATION_QUANTIZATION_FORMATS = [
+        CompressionFormat.naive_quantized.value,
+        CompressionFormat.int_quantized.value,
+        CompressionFormat.float_quantized.value
+    ]
+    return format in _ACTIVATION_QUANTIZATION_FORMATS
+
+
 def find_first_name_or_class_match(
     name: str,
     module: Module,
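
A quick usage check of the new helper, using the format strings from the `CompressionFormat` enum above (assumes the helper is imported from the compressed_tensors utils module):

# Activation-quantized formats return True; weight-only formats return False.
assert is_activation_quantization_format("float-quantized")
assert is_activation_quantization_format("int-quantized")
assert is_activation_quantization_format("naive-quantized")
assert not is_activation_quantization_format("pack-quantized")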
|
Loading…
x
Reference in New Issue
Block a user