[ Misc ] fp8-marlin channelwise via compressed-tensors (#6524)

Co-authored-by: mgoin <michael@neuralmagic.com>
Robert Shaw 2024-07-25 09:46:04 -07:00 committed by GitHub
parent b75e314fff
commit 889da130e7
13 changed files with 219 additions and 49 deletions

View File

@@ -0,0 +1,11 @@
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.578
- name: "exact_match,flexible-extract"
value: 0.585
limit: 1000
num_fewshot: 5
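As a hedged sketch of how a baseline yaml like this is consumed (the real logic lives in the .buildkite/lm-eval-harness scripts; the function name and tolerance below are assumptions), each measured metric is compared against the recorded value within a relative tolerance:

import yaml

RTOL = 0.05  # assumed relative tolerance, not necessarily the harness's value

def check_baseline(lm_eval_results: dict, baseline_path: str) -> None:
    baseline = yaml.safe_load(open(baseline_path))
    for task in baseline["tasks"]:
        for metric in task["metrics"]:
            measured = lm_eval_results["results"][task["name"]][metric["name"]]
            expected = metric["value"]
            assert abs(measured - expected) <= RTOL * expected, (
                f"{task['name']}/{metric['name']}: {measured} vs {expected}")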

View File

@@ -5,3 +5,4 @@ Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml

View File

@@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_matched_target, is_activation_quantization_format,
@@ -100,14 +101,18 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_config_filenames(cls) -> List[str]:
return []
def _check_scheme_supported(self, min_capability: int):
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < min_capability:
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
return supported
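# Illustrative aside, not part of this diff: get_device_capability() returns a
# (major, minor) tuple, so an Ada Lovelace GPU (8, 9) maps to 89 and an Ampere
# GPU (8, 0) maps to 80. CompressedTensorsW8A8Fp8.get_min_capability() is 89,
# while the fp8-marlin W8A16 fallback below only needs 80.
def _as_int_capability(capability) -> int:
    major, minor = capability
    return major * 10 + minor

assert _as_int_capability((8, 9)) == 89  # Ada: native fp8 (cutlass) path
assert _as_int_capability((8, 0)) == 80  # Ampere: weight-only fp8-marlin path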
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
@@ -170,6 +175,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# All conditions satisfied.
return True
def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights quantized.
if weight_quant is None:
return False
# Confirm we have floating points.
if weight_quant.type != QuantizationType.FLOAT:
return False
# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
return False
# All conditions satisfied.
return True
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
@@ -204,9 +232,23 @@ class CompressedTensorsConfig(QuantizationConfig):
# Detect If Activation Quantization.
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8(
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
if self._is_fp8_w8a16(weight_quant, input_quant):
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
@@ -257,11 +299,10 @@ class CompressedTensorsConfig(QuantizationConfig):
targets=self.target_scheme_map.keys())
# Find the quant_scheme
scheme = self.target_scheme_map[matched_target]
return self._get_scheme_from_parts(
weight_quant=scheme["weights"],
input_quant=scheme["input_activations"])
scheme_dict = self.target_scheme_map[matched_target]
scheme = self._get_scheme_from_parts(
weight_quant=scheme_dict["weights"],
input_quant=scheme_dict["input_activations"])
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)

View File

@@ -4,6 +4,7 @@ from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
@@ -11,6 +12,7 @@ __all__ = [
"CompressedTensorsScheme",
"CompressedTensorsUnquantized",
"CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",

View File

@@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors.
"""
@classmethod
@abstractmethod
def get_min_capability(self) -> int:
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""

View File

@@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
"""
def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# volta and up
return 70

View File

@@ -29,7 +29,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
raise ValueError(
"group_size must be given when using strategy group")
def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# ampere + up
return 80

View File

@@ -0,0 +1,105 @@
from typing import Callable, List, Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, create_per_channel_scale_param,
create_per_tensor_scale_param)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW8A16Fp8"]
SUPPORTED_STRATEGIES = [
QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
]
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
requires_grad=False)
# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(),
requires_grad=False)
prepare_fp8_layer_for_marlin(layer, strategy="channel")
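# Illustrative aside, not part of this diff: for a fused module (e.g. QKV) with
# logical_widths [q, k, v] and one scalar scale per shard, convert_to_channelwise
# conceptually repeats each shard's scale across that shard's output channels so
# the result can be treated as channelwise. Sketch under that assumption, not
# vLLM's actual helper:
def _expand_per_tensor_to_channelwise(weight_scale: torch.Tensor,
                                      logical_widths: List[int]) -> torch.Tensor:
    return torch.cat([
        scale.repeat(width)  # repeat the shard scale over its output channels
        for scale, width in zip(weight_scale, logical_widths)
    ])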
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
"weight_loader": weight_loader,
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, **layer_kwargs)
elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}")
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints)
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp8_marlin_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)
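For context, a minimal usage sketch: loading an FP8 compressed-tensors checkpoint such as the nm-testing/Qwen2-1.5B-Instruct-FP8W8 model from the test config above on a pre-Ada GPU (capability < 89, e.g. Ampere) routes through this W8A16 fp8-marlin scheme, while Ada/Hopper GPUs take the W8A8 cutlass scheme:

from vllm import LLM, SamplingParams

llm = LLM(model="nm-testing/Qwen2-1.5B-Instruct-FP8W8")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)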

View File

@@ -23,7 +23,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
@@ -77,19 +78,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
})
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, weight_loader=weight_loader)
output_partition_sizes, **layer_kwargs)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
output_partition_sizes, **layer_kwargs)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,

View File

@@ -19,7 +19,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75
@@ -68,19 +69,19 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
# WEIGHT SCALE
layer_kwargs = {"weight_loader": weight_loader}
if self.strategy == QuantizationStrategy.CHANNEL:
scale = create_per_channel_scale_param(output_partition_sizes,
**layer_kwargs)
weight_scale = create_per_channel_scale_param(
output_partition_sizes, **layer_kwargs)
else:
assert self.strategy == QuantizationStrategy.TENSOR
scale = create_per_tensor_scale_param(output_partition_sizes,
**layer_kwargs)
layer.register_parameter("weight_scale", scale)
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
scale = create_per_tensor_scale_param(output_partition_sizes,
**layer_kwargs)
layer.register_parameter("input_scale", scale)
input_scale = create_per_tensor_scale_param(
output_partition_sizes, **layer_kwargs)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:

View File

@@ -42,7 +42,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size=self.group_size,
is_sym=True)
def get_min_capability(self) -> int:
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80

View File

@@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
all_close_1d, apply_fp8_linear, convert_to_channelwise,
create_per_tensor_scale_param, cutlass_fp8_supported,
per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
@@ -179,19 +180,29 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.input_scale = None
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
else:
# Dequant -> Quant with max scale.
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if self.use_marlin:
weight = layer.weight
weight_scale = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else:
# Dequant -> Quant with max scale so we can run per tensor.
weight_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# Update layer with new values.
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
if self.quant_config.activation_scheme == "static":
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
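The requantization path above can be summarized with a hedged conceptual sketch (plain floats for brevity; this is not vLLM's actual requantize_with_max_scale implementation, which operates on torch.float8_e4m3fn weights):

import torch

def requantize_with_max_scale_sketch(weight: torch.Tensor,
                                     weight_scale: torch.Tensor,
                                     logical_widths: list):
    # Each logical shard of the fused weight has its own scale; dequantize with
    # the per-shard scale, then requantize with the shared max scale so that
    # torch._scaled_mm can use a single per-tensor scale.
    max_scale = weight_scale.max()
    out = weight.clone()
    start = 0
    for scale, width in zip(weight_scale, logical_widths):
        dequant = weight[start:start + width] * scale    # per-shard dequant
        out[start:start + width] = dequant / max_scale   # shared-scale requant
        start += width
    return max_scale, out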

View File

@@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
return output.reshape(out_shape)
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
strategy: str = "tensor") -> None:
print_warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
@@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
is_channelwise = (len(layer.weight_scale.shape) > 0
and layer.weight_scale.shape[0] == part_size_n)
if is_channelwise:
scales = layer.weight_scale
else:
scales = layer.weight_scale.repeat(1, part_size_n)
scales = scales.to(layer.orig_dtype).to(device)
scales = layer.weight_scale.to(layer.orig_dtype)
# Permute scales
marlin_scales = marlin_permute_scales(s=scales,
size_k=part_size_k,