[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)

Michael Goin 2024-07-03 13:38:00 -04:00 committed by GitHub
parent 7cd2ebb025
commit 47f0954af0
11 changed files with 1585 additions and 42 deletions

View File

@@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"

View File

@@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,

File diff suppressed because it is too large

View File

@@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization.
ops.def(

View File

@@ -4,7 +4,8 @@ FP8
==================
vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x.
Currently, only Hopper and Ada Lovelace GPUs are supported.
Currently, only Hopper and Ada Lovelace GPUs are officially supported for W8A8.
Ampere GPUs are supported for W8A16 (weight-only FP8) utilizing Marlin kernels.
Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.
Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127>`_.
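As a usage sketch (an editorial addition, not part of this diff): an ordinary fp16 checkpoint can be quantized to FP8 at load time with the quantization="fp8" flag, and with this change the same call works on Ampere by falling back to the weight-only Marlin path instead of requiring native FP8 hardware.

from vllm import LLM

# Quantize an fp16 model to FP8 at load time. On Hopper/Ada this runs W8A8 with
# native FP8 GEMMs; on Ampere the FP8 weights are repacked for the W8A16
# FP8 Marlin kernel added by this commit.
llm = LLM(model="facebook/opt-125m", quantization="fp8")
print(llm.generate("Hello, my name is")[0].outputs[0].text)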

View File

@@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
AQLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
AWQ ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
DeepSpeedFP ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌ ❌ ✅ ✅ ❌ ❌ ❌ ❌ ❌
FP8 ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
Marlin ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
GPTQ ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
SqueezeLLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌

View File

@@ -8,7 +8,8 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
marlin_permute_scales)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
marlin_quantize, marlin_weights)
marlin_quantize, marlin_weights, pack_fp8_to_int32)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, quantize_weights, sort_weights)
@@ -38,9 +39,11 @@ MNK_FACTORS = [
(67, 13, 11),
]
DTYPES = [torch.float16, torch.bfloat16]
def rand_data(shape):
return torch.randn(shape, dtype=torch.half, device="cuda")
def rand_data(shape, dtype=torch.float16):
return torch.randn(shape, dtype=dtype, device="cuda")
@pytest.mark.skipif(not is_marlin_supported(),
@@ -217,3 +220,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04
@pytest.mark.skipif(not is_marlin_supported(),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
@pytest.mark.parametrize("num_bits", [8])
@pytest.mark.parametrize("group_size", [-1])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("dtype", DTYPES)
def test_fp8_marlin_gemm(
k_chunk,
n_chunk,
num_bits,
group_size,
mnk_factors,
dtype,
):
m_factor, n_factor, k_factor = mnk_factors
size_m = m_factor
size_k = k_chunk * k_factor
size_n = n_chunk * n_factor
print(f"MNK = {size_m} {size_n} {size_k}")
print(f"groupsize = {group_size}")
a_input = rand_data((size_m, size_k), dtype=dtype)
b_weight = rand_data((size_k, size_n), dtype=dtype)
# WEIGHTS
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device="cuda"),
size_k=size_k,
size_n=size_n,
num_bits=8,
)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
# Permute scales
marlin_scales = marlin_permute_scales(
s=scales,
size_k=size_k,
size_n=size_n,
group_size=-1,
num_bits=8,
)
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
output = ops.fp8_marlin_gemm(
a=a_input,
b_q_weight=marlin_qweight,
b_scales=marlin_scales,
workspace=workspace.scratch,
num_bits=num_bits,
size_m=a_input.shape[0],
size_n=b_weight.shape[1],
size_k=a_input.shape[1],
)
output_ref = torch.matmul(a_input, b_weight)
torch.cuda.synchronize()
max_diff = compute_max_diff(output, output_ref)
print("max_diff = {}".format(max_diff))
assert max_diff < 0.04

View File

@@ -6,7 +6,7 @@ import pytest
import torch
from tests.quantization.utils import is_quant_method_supported
from vllm._custom_ops import scaled_fp8_quant
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
MODELS = [
@@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
assert fc1.weight.dtype == torch.float8_e4m3fn
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability >= 89:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
@@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
# Dynamic quantization
ref_y, inv_scale = scaled_fp8_quant(x, None)
ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
# Reference dynamic quantization
@@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
# Static quantization
y, _ = scaled_fp8_quant(x, inv_scale)
y, _ = ops.scaled_fp8_quant(x, inv_scale)
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
# Padding
y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
assert y.shape[0] == 17
assert torch.allclose(
ref_y,

View File

@@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k, is_k_full)
# fp8 marlin
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
num_bits, size_m, size_n, size_k)
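# Editorial reference (not part of this commit): ignoring Marlin's tiling and
# fp16 accumulation, fp8_marlin_gemm is equivalent to dequantizing the FP8
# weight and running a plain matmul. A rough sketch, assuming an *unpacked*
# float8_e4m3fn weight of shape (size_k, size_n) and a scale that broadcasts
# against it (per-tensor or per-channel):
def fp8_gemm_reference(a: torch.Tensor, fp8_weight: torch.Tensor,
                       weight_scale: torch.Tensor) -> torch.Tensor:
    w = fp8_weight.to(a.dtype) * weight_scale.to(a.dtype)  # dequantize
    return a @ w  # (size_m, size_k) @ (size_k, size_n)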
# fp8
def scaled_fp8_quant(
input: torch.Tensor,

View File

@@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
pack_fp8_to_int32)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import print_warning_once
@@ -54,7 +59,7 @@ class Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 89
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
@@ -106,6 +111,12 @@ class Fp8LinearMethod(LinearMethodBase):
self.quant_config = quant_config
self.cutlass_fp8_supported = cutlass_fp8_supported()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89
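# e.g. an A100 reports compute capability (8, 0) -> 80, so it takes the Marlin
# path; Ada (8, 9) -> 89 and H100 (9, 0) -> 90 keep the native FP8 GEMMs.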
def _create_scale_param(
self,
scale_name: str,
@@ -139,6 +150,10 @@ class Fp8LinearMethod(LinearMethodBase):
layer.process_after_load = True
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
@@ -172,6 +187,65 @@ class Fp8LinearMethod(LinearMethodBase):
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
# For GPUs without FP8 hardware support, we use Marlin for fast
# fused dequantization
if self.use_marlin:
layer.marlin_state = GPTQMarlinState.REPACK
def prepare_layer_for_marlin(self, layer: Module) -> None:
print_warning_once(
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"be used leveraging the Marlin kernel. This may degrade "
"performance for compute-heavy workloads.")
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
assert layer.marlin_state == GPTQMarlinState.REPACK
layer.marlin_state = GPTQMarlinState.READY
device = layer.weight.device
# WEIGHTS
# Repack weights to gptq format (packed int32 elements)
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
# Repack weights to marlin format
marlin_qweight = ops.gptq_marlin_repack(
b_q_weight=packed_gptq_qweight,
perm=torch.empty(0, dtype=torch.int, device=device),
size_k=part_size_k,
size_n=part_size_n,
num_bits=8,
)
layer.weight = Parameter(marlin_qweight, requires_grad=False)
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = layer.weight_scale.repeat(1, part_size_n).to(
layer.orig_dtype).to(device)
# Permute scales
marlin_scales = marlin_permute_scales(
s=scales,
size_k=part_size_k,
size_n=part_size_n,
group_size=-1,
num_bits=8,
)
layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
# Allocate marlin workspace
max_workspace_size = (
part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
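# Rough size check (editorial note): assuming GPTQ_MARLIN_MIN_THREAD_N == 64
# and GPTQ_MARLIN_MAX_PARALLEL == 16, a 4096-column partition gets a
# (4096 // 64) * 16 = 1024-entry int32 scratch buffer.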
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device=device,
requires_grad=False)
layer.workspace = workspace
def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
@@ -185,6 +259,8 @@ class Fp8LinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.input_scale = None
if self.use_marlin:
self.prepare_layer_for_marlin(layer)
return
# If checkpoint is fp8, requantize the separately quantized logical
@@ -233,44 +309,72 @@ class Fp8LinearMethod(LinearMethodBase):
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")
if self.use_marlin:
self.prepare_layer_for_marlin(layer)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
if self.use_marlin:
# For GPUs that lack FP8 hardware support, we can leverage the
# Marlin kernel for fast weight-only FP8 quantization
if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
output = ops.fp8_marlin_gemm(
a=reshaped_x,
b_q_weight=layer.weight,
b_scales=layer.weight_scale,
workspace=layer.workspace,
num_bits=8,
size_m=reshaped_x.shape[0],
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
else:
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.input_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x
# If static, layer.input_scale is scalar and x_scale is input_scale
if bias is None and self.cutlass_fp8_supported:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
)
else:
qinput, x_scale = ops.scaled_fp8_quant(x,
layer.input_scale,
batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)
return torch.narrow(output, 0, 0, x.shape[0])
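# Editorial summary of the dispatch in apply() above (no new behavior):
#   use_marlin (capability < 89)        -> ops.fp8_marlin_gemm on fp16/bf16
#                                          activations with FP8 weights (W8A16)
#   bias is None and cutlass supported  -> ops.cutlass_scaled_mm on quantized
#                                          activations (W8A8)
#   otherwise                           -> torch._scaled_mm on padded input,
#                                          then torch.narrow back to x.shape[0]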

View File

@@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor, quantize_weights, sort_weights)
from vllm.platforms import current_platform
__cuda_arch = current_platform.get_device_capability()
MARLIN_TILE = 16
def is_marlin_supported():
return __cuda_arch[0] >= 8
capability = current_platform.get_device_capability()
return capability[0] >= 8
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
@@ -223,3 +222,26 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda")
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.shape[0] % 4 == 0
# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)
# Pack 4 uint8 values into one int32
packed = (byte_tensor[:, 0].to(torch.int32) |
(byte_tensor[:, 1].to(torch.int32) << 8) |
(byte_tensor[:, 2].to(torch.int32) << 16) |
(byte_tensor[:, 3].to(torch.int32) << 24))
return packed.view(fp8_tensor.shape[0] // 4,
*fp8_tensor.shape[1:]).contiguous()
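For illustration only (an editorial sketch, not part of this commit), the packing can be inverted by splitting each int32 back into its four bytes, where byte i of group g holds row 4*g + i, and reinterpreting the bytes as float8_e4m3fn:

def unpack_int32_to_fp8(packed: torch.Tensor) -> torch.Tensor:
    """Hypothetical inverse of pack_fp8_to_int32 (illustrative only)."""
    assert packed.dtype == torch.int32
    # Extract the four packed bytes; the 0xFF mask keeps this correct even when
    # the int32 is negative (arithmetic right shift).
    bytes_ = [((packed >> (8 * i)) & 0xFF).to(torch.uint8) for i in range(4)]
    stacked = torch.stack(bytes_, dim=1)  # (size_k // 4, 4, *rest)
    # Merge the byte dim back into the row dim and reinterpret bytes as fp8
    return stacked.reshape(packed.shape[0] * 4,
                           *packed.shape[1:]).view(torch.float8_e4m3fn)

A round trip through pack_fp8_to_int32 and this helper should be bit-exact; comparing uint8 views (w.view(torch.uint8)) avoids surprises from NaN encodings that never compare equal as floats.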