[ Misc ] fp8-marlin
channelwise via compressed-tensors
(#6524)
Co-authored-by: mgoin <michael@neuralmagic.com>
This commit is contained in:
parent
b75e314fff
commit
889da130e7
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.578
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.585
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -5,3 +5,4 @@ Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-FP8W8.yaml
|
||||
|
@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsScheme, CompressedTensorsUnquantized,
|
||||
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
|
||||
CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
CompressionFormat, QuantizationArgs, QuantizationStrategy,
|
||||
QuantizationType, find_matched_target, is_activation_quantization_format,
|
||||
@ -100,14 +101,18 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self, min_capability: int):
|
||||
def _check_scheme_supported(self,
|
||||
min_capability: int,
|
||||
error: bool = True) -> bool:
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability < min_capability:
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
f"the current GPU. Min capability: {min_capability}. ",
|
||||
f"Current capability: {capability}.")
|
||||
return supported
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
@ -170,6 +175,29 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# All conditions satisfied.
|
||||
return True
|
||||
|
||||
def _is_fp8_w8a16(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
# Confirm weights quantized.
|
||||
if weight_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm we have floating points.
|
||||
if weight_quant.type != QuantizationType.FLOAT:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
|
||||
])
|
||||
if not (is_symmetric_weight and is_static_weight
|
||||
and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# All conditions satisfied.
|
||||
return True
|
||||
|
||||
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
|
||||
input_quant: BaseModel) -> bool:
|
||||
input_quant_none = input_quant is None
|
||||
@ -204,9 +232,23 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
# Detect If Activation Quantization.
|
||||
if is_activation_quantization_format(self.quant_format):
|
||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
|
||||
if is_fp8_w8a8_supported:
|
||||
return CompressedTensorsW8A8Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=(not input_quant.dynamic))
|
||||
else:
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=(input_quant
|
||||
and not input_quant.dynamic))
|
||||
|
||||
if self._is_fp8_w8a16(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A16Fp8(
|
||||
strategy=weight_quant.strategy,
|
||||
is_static_input_scheme=(input_quant
|
||||
and not input_quant.dynamic))
|
||||
|
||||
if self._is_static_tensor_w8a8(weight_quant, input_quant):
|
||||
return CompressedTensorsW8A8Int8(
|
||||
@ -257,11 +299,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
targets=self.target_scheme_map.keys())
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self.target_scheme_map[matched_target]
|
||||
|
||||
return self._get_scheme_from_parts(
|
||||
weight_quant=scheme["weights"],
|
||||
input_quant=scheme["input_activations"])
|
||||
scheme_dict = self.target_scheme_map[matched_target]
|
||||
scheme = self._get_scheme_from_parts(
|
||||
weight_quant=scheme_dict["weights"],
|
||||
input_quant=scheme_dict["input_activations"])
|
||||
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
|
@ -4,6 +4,7 @@ from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24)
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
|
||||
CompressedTensorsWNA16)
|
||||
|
||||
@ -11,6 +12,7 @@ __all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsUnquantized",
|
||||
"CompressedTensorsWNA16",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
|
@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC):
|
||||
of different quantization schemes supported by CompressedTensors.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(self) -> int:
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
|
@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
|
||||
in a linear transformation.
|
||||
"""
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# volta and up
|
||||
return 70
|
||||
|
||||
|
@ -29,7 +29,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
raise ValueError(
|
||||
"group_size must be given when using strategy group")
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere + up
|
||||
return 80
|
||||
|
||||
|
@ -0,0 +1,105 @@
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
QuantizationStrategy)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise, create_per_channel_scale_param,
|
||||
create_per_tensor_scale_param)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||
|
||||
SUPPORTED_STRATEGIES = [
|
||||
QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
|
||||
]
|
||||
|
||||
|
||||
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(self, strategy: str, is_static_input_scheme: bool):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere and up
|
||||
return 80
|
||||
|
||||
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
|
||||
# So if we have a fused module (QKV, MLP) with per tensor scales,
|
||||
# we expand each scale to its shard's channels.
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
ws_channelwise = convert_to_channelwise(layer.weight_scale,
|
||||
layer.logical_widths)
|
||||
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
|
||||
requires_grad=False)
|
||||
|
||||
# Weights must be transposed for marlin
|
||||
layer.weight = torch.nn.Parameter(layer.weight.t(),
|
||||
requires_grad=False)
|
||||
|
||||
prepare_fp8_layer_for_marlin(layer, strategy="channel")
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, input_size: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
weight = torch.nn.Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
"weight_loader": weight_loader,
|
||||
})
|
||||
|
||||
# WEIGHT SCALE
|
||||
layer_kwargs = {"weight_loader": weight_loader}
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = create_per_channel_scale_param(
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
elif self.strategy == QuantizationStrategy.TENSOR:
|
||||
weight_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported weight strategy={self.strategy}, "
|
||||
f"supported strategies are {SUPPORTED_STRATEGIES}")
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE (to deal with converted checkpoints)
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return apply_fp8_marlin_linear(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias)
|
@ -23,7 +23,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
@ -77,19 +78,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
|
||||
})
|
||||
|
||||
# WEIGHT SCALE
|
||||
layer_kwargs = {"weight_loader": weight_loader}
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = create_per_channel_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, weight_loader=weight_loader)
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
|
@ -19,7 +19,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
@ -68,19 +69,19 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
# WEIGHT SCALE
|
||||
layer_kwargs = {"weight_loader": weight_loader}
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
scale = create_per_channel_scale_param(output_partition_sizes,
|
||||
**layer_kwargs)
|
||||
weight_scale = create_per_channel_scale_param(
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
scale = create_per_tensor_scale_param(output_partition_sizes,
|
||||
**layer_kwargs)
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
weight_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
scale = create_per_tensor_scale_param(output_partition_sizes,
|
||||
**layer_kwargs)
|
||||
layer.register_parameter("input_scale", scale)
|
||||
input_scale = create_per_tensor_scale_param(
|
||||
output_partition_sizes, **layer_kwargs)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
|
@ -42,7 +42,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||
group_size=self.group_size,
|
||||
is_sym=True)
|
||||
|
||||
def get_min_capability(self) -> int:
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere and up
|
||||
return 80
|
||||
|
||||
|
@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
is_layer_skipped)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, apply_fp8_linear, create_per_tensor_scale_param,
|
||||
cutlass_fp8_supported, per_tensor_dequantize, requantize_with_max_scale)
|
||||
all_close_1d, apply_fp8_linear, convert_to_channelwise,
|
||||
create_per_tensor_scale_param, cutlass_fp8_supported,
|
||||
per_tensor_dequantize, requantize_with_max_scale)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
@ -179,11 +180,21 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.input_scale = None
|
||||
|
||||
# If checkpoint is fp8, requantize the separately quantized logical
|
||||
# weights into a single fp8 weight with a single weight scale.
|
||||
# If checkpoint is fp8, handle that there are N scales for N
|
||||
# shards in a fused module
|
||||
else:
|
||||
# Dequant -> Quant with max scale.
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
# If using marlin (w8a16), kernel uses channelwise weights,
|
||||
# so extend the weight scales to be channelwise.
|
||||
if self.use_marlin:
|
||||
weight = layer.weight
|
||||
weight_scale = convert_to_channelwise(layer.weight_scale,
|
||||
layer.logical_widths)
|
||||
|
||||
# If using w8a8, torch._scaled_mm needs per tensor, so
|
||||
# requantize the logical shards as a single weight.
|
||||
else:
|
||||
# Dequant -> Quant with max scale so we can run per tensor.
|
||||
weight_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
@ -191,7 +202,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
# Update layer with new values.
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
if self.quant_config.activation_scheme == "static":
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
|
@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
def prepare_fp8_layer_for_marlin(layer: torch.nn.Module,
|
||||
strategy: str = "tensor") -> None:
|
||||
print_warning_once(
|
||||
"Your GPU does not have native support for FP8 computation but "
|
||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||
@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
is_channelwise = (len(layer.weight_scale.shape) > 0
|
||||
and layer.weight_scale.shape[0] == part_size_n)
|
||||
if is_channelwise:
|
||||
scales = layer.weight_scale
|
||||
else:
|
||||
scales = layer.weight_scale.repeat(1, part_size_n)
|
||||
scales = scales.to(layer.orig_dtype).to(device)
|
||||
|
||||
scales = layer.weight_scale.to(layer.orig_dtype)
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(s=scales,
|
||||
size_k=part_size_k,
|
||||
|
Loading…
x
Reference in New Issue
Block a user