[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)
commit 47f0954af0
parent 7cd2ebb025
@@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin.cu"
     "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
+    "csrc/quantization/fp8/fp8_marlin.cu"
     "csrc/custom_all_reduce.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
@@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
                                  int64_t size_k, int64_t size_n,
                                  int64_t num_bits);
 
+torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
+                              torch::Tensor& b_scales, torch::Tensor& workspace,
+                              int64_t num_bits, int64_t size_m, int64_t size_n,
+                              int64_t size_k);
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
csrc/quantization/fp8/fp8_marlin.cu (new file, 1308 lines)
File diff suppressed because it is too large.
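Since the kernel diff itself is suppressed, the following is a rough, unfused PyTorch reference (an illustration only, not the kernel source) of the W8A16 computation that fp8_marlin_gemm fuses: FP8 weights are dequantized to the activation dtype, scaled per output channel, and multiplied against half-precision activations.

```python
import torch

def fp8_w8a16_reference(a: torch.Tensor, w_fp8: torch.Tensor,
                        w_scale: torch.Tensor) -> torch.Tensor:
    """Unfused reference for the math fp8_marlin_gemm fuses (illustrative).

    a:       [M, K] activations (fp16/bf16 on GPU)
    w_fp8:   [K, N] weights stored as float8_e4m3fn
    w_scale: [N] per-output-channel dequantization scales
    """
    w = w_fp8.to(a.dtype) * w_scale.to(a.dtype)  # dequantize the weights only
    return a @ w                                 # plain GEMM in a.dtype

# Small CPU-friendly example (requires a PyTorch build with float8 support).
a = torch.randn(4, 64)                           # fp32 here; fp16/bf16 on GPU
w = (torch.randn(64, 32) / 8).to(torch.float8_e4m3fn)
scale = torch.full((32,), 8.0)
print(fp8_w8a16_reference(a, w, scale).shape)    # torch.Size([4, 32])
```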
@@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gptq_marlin_repack", &gptq_marlin_repack);
   ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
 
+  // fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
+  ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
+  ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization.
   ops.def(
@@ -4,7 +4,8 @@ FP8
 ==================
 
 vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x.
-Currently, only Hopper and Ada Lovelace GPUs are supported.
+Currently, only Hopper and Ada Lovelace GPUs are officially supported for W8A8.
+Ampere GPUs are supported for W8A16 (weight-only FP8) utilizing Marlin kernels.
 Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.
 
 Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127>`_.
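For orientation, a minimal usage sketch of what this documentation change describes; the model name and prompt are placeholders, and quantization="fp8" quantizes an unquantized fp16 checkpoint on the fly, mirroring the loading test later in this diff.

```python
from vllm import LLM, SamplingParams

# On an Ampere GPU (e.g. A100, compute capability 8.0), FP8 now loads via the
# weight-only (W8A16) Marlin path instead of being rejected outright.
llm = LLM(model="facebook/opt-125m", quantization="fp8")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```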
@@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
 AQLM         ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 AWQ          ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 DeepSpeedFP  ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
-FP8          ❌ ❌ ❌ ✅ ✅ ❌ ❌ ❌ ❌ ❌
+FP8          ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 Marlin       ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 GPTQ         ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
 SqueezeLLM   ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
@@ -8,7 +8,8 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
-    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
+    GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
+    marlin_permute_scales)
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
     GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
     marlin_perm)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
-    marlin_quantize, marlin_weights)
+    marlin_quantize, marlin_weights, pack_fp8_to_int32)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     gptq_pack, quantize_weights, sort_weights)
 
@@ -38,9 +39,11 @@ MNK_FACTORS = [
     (67, 13, 11),
 ]
 
+DTYPES = [torch.float16, torch.bfloat16]
+
 
-def rand_data(shape):
-    return torch.randn(shape, dtype=torch.half, device="cuda")
+def rand_data(shape, dtype=torch.float16):
+    return torch.randn(shape, dtype=dtype, device="cuda")
 
 
 @pytest.mark.skipif(not is_marlin_supported(),
@@ -217,3 +220,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
     print("max_diff = {}".format(max_diff))
 
     assert max_diff < 0.04
+
+
+@pytest.mark.skipif(not is_marlin_supported(),
+                    reason="Marlin is not supported on this GPU type.")
+@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
+@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
+@pytest.mark.parametrize("num_bits", [8])
+@pytest.mark.parametrize("group_size", [-1])
+@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
+@pytest.mark.parametrize("dtype", DTYPES)
+def test_fp8_marlin_gemm(
+    k_chunk,
+    n_chunk,
+    num_bits,
+    group_size,
+    mnk_factors,
+    dtype,
+):
+    m_factor, n_factor, k_factor = mnk_factors
+
+    size_m = m_factor
+    size_k = k_chunk * k_factor
+    size_n = n_chunk * n_factor
+
+    print(f"MNK = {size_m} {size_n} {size_k}")
+    print(f"groupsize = {group_size}")
+
+    a_input = rand_data((size_m, size_k), dtype=dtype)
+    b_weight = rand_data((size_k, size_n), dtype=dtype)
+
+    # WEIGHTS
+    fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
+    # Repack weights to gptq format (packed int32 elements)
+    packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
+    # Repack weights to marlin format
+    marlin_qweight = ops.gptq_marlin_repack(
+        b_q_weight=packed_gptq_qweight,
+        perm=torch.empty(0, dtype=torch.int, device="cuda"),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    # WEIGHT SCALES
+    # Currently Marlin doesn't support per-tensor scales, so we
+    # expand it to channelwise
+    scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
+    # Permute scales
+    marlin_scales = marlin_permute_scales(
+        s=scales,
+        size_k=size_k,
+        size_n=size_n,
+        group_size=-1,
+        num_bits=8,
+    )
+
+    workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
+                                GPTQ_MARLIN_MAX_PARALLEL)
+
+    output = ops.fp8_marlin_gemm(
+        a=a_input,
+        b_q_weight=marlin_qweight,
+        b_scales=marlin_scales,
+        workspace=workspace.scratch,
+        num_bits=num_bits,
+        size_m=a_input.shape[0],
+        size_n=b_weight.shape[1],
+        size_k=a_input.shape[1],
+    )
+    output_ref = torch.matmul(a_input, b_weight)
+
+    torch.cuda.synchronize()
+
+    max_diff = compute_max_diff(output, output_ref)
+    print("max_diff = {}".format(max_diff))
+
+    assert max_diff < 0.04
@@ -6,7 +6,7 @@ import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from vllm._custom_ops import scaled_fp8_quant
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 
 MODELS = [
@@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
         model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
         fc1 = model.model.decoder.layers[0].fc1
         assert isinstance(fc1.quant_method, Fp8LinearMethod)
-        assert fc1.weight.dtype == torch.float8_e4m3fn
+
+        capability = torch.cuda.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        if capability >= 89:
+            # For GPUs with hardware support, we keep weights in fp8
+            assert fc1.weight.dtype == torch.float8_e4m3fn
+        else:
+            # For GPUs without hardware support, we pack the fp8 weights
+            # for weight-only quantization using Marlin kernels
+            assert fc1.weight.dtype == torch.int32
 
 
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
@@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
     x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
 
     # Dynamic quantization
-    ref_y, inv_scale = scaled_fp8_quant(x, None)
+    ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
     ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
 
     # Reference dynamic quantizaton
@@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
 
     # Static quantization
-    y, _ = scaled_fp8_quant(x, inv_scale)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale)
     assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
 
     # Padding
-    y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
+    y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
     assert y.shape[0] == 17
     assert torch.allclose(
         ref_y,
@@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
                                         size_k, is_k_full)
 
 
+# fp8 marlin
+def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
+                    b_scales: torch.Tensor, workspace: torch.Tensor,
+                    num_bits: int, size_m: int, size_n: int,
+                    size_k: int) -> torch.Tensor:
+    return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
+                                        num_bits, size_m, size_n, size_k)
+
+
 # fp8
 def scaled_fp8_quant(
     input: torch.Tensor,
@@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
+    marlin_permute_scales)
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    pack_fp8_to_int32)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.utils import print_warning_once
@@ -54,7 +59,7 @@ class Fp8Config(QuantizationConfig):
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 89
+        return 80
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
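As a quick illustration of why the minimum capability drops from 89 to 80: vLLM encodes the compute capability as major * 10 + minor, so Ampere parts now pass the config check, while only 8.9 and newer keep the native FP8 path. The GPU names below are just common examples for orientation.

```python
# Compute capability encoding used by the checks in this file (major * 10 + minor).
gpus = {
    "A100 (Ampere)": (8, 0),
    "A10 / RTX 3090 (Ampere)": (8, 6),
    "L4 / RTX 4090 (Ada)": (8, 9),
    "H100 (Hopper)": (9, 0),
}
for name, (major, minor) in gpus.items():
    cap = major * 10 + minor
    path = "native FP8 (W8A8)" if cap >= 89 else "FP8 Marlin (W8A16)"
    print(f"{name}: capability {cap} -> {path}")
```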
@@ -106,6 +111,12 @@ class Fp8LinearMethod(LinearMethodBase):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization
+        capability = current_platform.get_device_capability()
+        capability = capability[0] * 10 + capability[1]
+        self.use_marlin = capability < 89
+
     def _create_scale_param(
         self,
         scale_name: str,
@@ -139,6 +150,10 @@ class Fp8LinearMethod(LinearMethodBase):
         layer.process_after_load = True
         layer.logical_widths = output_partition_sizes
 
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
         # WEIGHT
         weight_dtype = (torch.float8_e4m3fn
                         if self.quant_config.is_checkpoint_fp8_serialized else
@@ -172,6 +187,65 @@ class Fp8LinearMethod(LinearMethodBase):
             output_partition_sizes=output_partition_sizes,
             **extra_weight_attrs)
 
+        # For GPUs without FP8 hardware support, we use Marlin for fast
+        # fused dequantization
+        if self.use_marlin:
+            layer.marlin_state = GPTQMarlinState.REPACK
+
+    def prepare_layer_for_marlin(self, layer: Module) -> None:
+        print_warning_once(
+            "Your GPU does not have native support for FP8 computation but "
+            "FP8 quantization is being used. Weight-only FP8 compression will "
+            "be used leveraging the Marlin kernel. This may degrade "
+            "performance for compute-heavy workloads.")
+
+        part_size_n = layer.output_size_per_partition
+        part_size_k = layer.input_size_per_partition
+
+        assert layer.marlin_state == GPTQMarlinState.REPACK
+        layer.marlin_state = GPTQMarlinState.READY
+
+        device = layer.weight.device
+
+        # WEIGHTS
+        # Repack weights to gptq format (packed int32 elements)
+        packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
+
+        # Repack weights to marlin format
+        marlin_qweight = ops.gptq_marlin_repack(
+            b_q_weight=packed_gptq_qweight,
+            perm=torch.empty(0, dtype=torch.int, device=device),
+            size_k=part_size_k,
+            size_n=part_size_n,
+            num_bits=8,
+        )
+        layer.weight = Parameter(marlin_qweight, requires_grad=False)
+
+        # WEIGHT SCALES
+        # Currently Marlin doesn't support per-tensor scales, so we
+        # expand it to channelwise
+        scales = layer.weight_scale.repeat(1, part_size_n).to(
+            layer.orig_dtype).to(device)
+        # Permute scales
+        marlin_scales = marlin_permute_scales(
+            s=scales,
+            size_k=part_size_k,
+            size_n=part_size_n,
+            group_size=-1,
+            num_bits=8,
+        )
+        layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
+
+        # Allocate marlin workspace
+        max_workspace_size = (
+            part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
+        workspace = torch.zeros(max_workspace_size,
+                                dtype=torch.int,
+                                device=device,
+                                requires_grad=False)
+
+        layer.workspace = workspace
+
     def process_weights_after_loading(self, layer: Module) -> None:
         if (not hasattr(layer, "process_after_load")
                 or not layer.process_after_load):
@@ -185,6 +259,8 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
             layer.logical_widths = None
             layer.input_scale = None
+            if self.use_marlin:
+                self.prepare_layer_for_marlin(layer)
             return
 
         # If checkpoint is fp8, requantize the separately quantized logical
# If checkpoint is fp8, requantize the separately quantized logical
|
# If checkpoint is fp8, requantize the separately quantized logical
|
||||||
@ -233,44 +309,72 @@ class Fp8LinearMethod(LinearMethodBase):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unknown scheme {self.quant_config.activation_scheme}")
|
f"Unknown scheme {self.quant_config.activation_scheme}")
|
||||||
|
|
||||||
|
if self.use_marlin:
|
||||||
|
self.prepare_layer_for_marlin(layer)
|
||||||
|
|
||||||
def apply(self,
|
def apply(self,
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
|
|
||||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
if self.use_marlin:
|
||||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
# For GPUs that lack FP8 hardware support, we can leverage the
|
||||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
# Marlin kernel for fast weight-only FP8 quantization
|
||||||
|
|
||||||
if bias is None and self.cutlass_fp8_supported:
|
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
|
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
|
||||||
|
|
||||||
# Fused GEMM_DQ
|
output = ops.fp8_marlin_gemm(
|
||||||
output = ops.cutlass_scaled_mm(
|
a=reshaped_x,
|
||||||
qinput,
|
b_q_weight=layer.weight,
|
||||||
layer.weight,
|
b_scales=layer.weight_scale,
|
||||||
out_dtype=x.dtype,
|
workspace=layer.workspace,
|
||||||
scale_a=x_scale,
|
num_bits=8,
|
||||||
scale_b=layer.weight_scale,
|
size_m=reshaped_x.shape[0],
|
||||||
|
size_n=layer.output_size_per_partition,
|
||||||
|
size_k=layer.input_size_per_partition,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if bias is not None:
|
||||||
|
output.add_(bias) # In-place add
|
||||||
|
|
||||||
|
return output.reshape(out_shape)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
qinput, x_scale = ops.scaled_fp8_quant(x,
|
|
||||||
layer.input_scale,
|
|
||||||
batch_dim_padding=17)
|
|
||||||
|
|
||||||
# Fused GEMM_DQ -- note we padded the input above because
|
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||||
# torch._scaled_mm is more performant for matrices with
|
# If dynamic, layer.input_scale is None and x_scale computed from x
|
||||||
# batch dimension > 16. Note that this could change
|
# If static, layer.input_scale is scalar and x_scale is input_scale
|
||||||
# in the future.
|
|
||||||
output, _ = torch._scaled_mm(
|
if bias is None and self.cutlass_fp8_supported:
|
||||||
qinput,
|
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
|
||||||
layer.weight,
|
|
||||||
out_dtype=x.dtype,
|
# Fused GEMM_DQ
|
||||||
scale_a=x_scale,
|
output = ops.cutlass_scaled_mm(
|
||||||
scale_b=layer.weight_scale,
|
qinput,
|
||||||
bias=bias,
|
layer.weight,
|
||||||
)
|
out_dtype=x.dtype,
|
||||||
|
scale_a=x_scale,
|
||||||
|
scale_b=layer.weight_scale,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
qinput, x_scale = ops.scaled_fp8_quant(x,
|
||||||
|
layer.input_scale,
|
||||||
|
batch_dim_padding=17)
|
||||||
|
|
||||||
|
# Fused GEMM_DQ -- note we padded the input above because
|
||||||
|
# torch._scaled_mm is more performant for matrices with
|
||||||
|
# batch dimension > 16. Note that this could change
|
||||||
|
# in the future.
|
||||||
|
output, _ = torch._scaled_mm(
|
||||||
|
qinput,
|
||||||
|
layer.weight,
|
||||||
|
out_dtype=x.dtype,
|
||||||
|
scale_a=x_scale,
|
||||||
|
scale_b=layer.weight_scale,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
|
|
||||||
return torch.narrow(output, 0, 0, x.shape[0])
|
return torch.narrow(output, 0, 0, x.shape[0])
|
||||||
|
|
||||||
|
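To summarize the new control flow in apply(), a small sketch follows; the helper name fp8_backend is hypothetical and the real method keeps this logic inline.

```python
def fp8_backend(capability: tuple, has_bias: bool,
                cutlass_fp8_supported: bool) -> str:
    """Which GEMM path Fp8LinearMethod.apply() takes after this change."""
    cap = capability[0] * 10 + capability[1]
    if cap < 89:
        # Ampere and older: weight-only FP8 via the Marlin kernel (W8A16)
        return "ops.fp8_marlin_gemm"
    if cutlass_fp8_supported and not has_bias:
        # Native FP8 with the CUTLASS scaled GEMM (W8A8)
        return "ops.cutlass_scaled_mm"
    # Native FP8 fallback, input padded so the batch dimension exceeds 16
    return "torch._scaled_mm"

print(fp8_backend((8, 0), has_bias=False, cutlass_fp8_supported=True))  # Marlin
print(fp8_backend((9, 0), has_bias=False, cutlass_fp8_supported=True))  # CUTLASS
```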
@@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
 from vllm.platforms import current_platform
 
-__cuda_arch = current_platform.get_device_capability()
-
 MARLIN_TILE = 16
 
 
 def is_marlin_supported():
-    return __cuda_arch[0] >= 8
+    capability = current_platform.get_device_capability()
+    return capability[0] >= 8
 
 
 def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
@@ -223,3 +222,26 @@ class MarlinWorkspace:
         self.scratch = torch.zeros(max_workspace_size,
                                    dtype=torch.int,
                                    device="cuda")
+
+
+def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.shape[0] % 4 == 0
+
+    # Reshape to prepare for packing
+    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
+
+    # Convert fp8 to uint8 (byte) representation
+    byte_tensor = reshaped.view(torch.uint8)
+
+    # Pack 4 uint8 values into one int32
+    packed = (byte_tensor[:, 0].to(torch.int32) |
+              (byte_tensor[:, 1].to(torch.int32) << 8) |
+              (byte_tensor[:, 2].to(torch.int32) << 16) |
+              (byte_tensor[:, 3].to(torch.int32) << 24))
+
+    return packed.view(fp8_tensor.shape[0] // 4,
+                       *fp8_tensor.shape[1:]).contiguous()
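A small sanity check (illustrative, not part of the diff) showing that this packing is a pure byte-level reinterpretation, so unpacking each int32 recovers the original FP8 bytes:

```python
import torch

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    pack_fp8_to_int32)

# 8 x 16 fp8 block; shape[0] must be divisible by 4 (needs float8 support).
fp8 = (torch.randn(8, 16) / 8).to(torch.float8_e4m3fn)
packed = pack_fp8_to_int32(fp8)
assert packed.dtype == torch.int32 and packed.shape == (2, 16)

# Pull byte i back out of every int32 and reinterpret it as float8_e4m3fn.
recovered = torch.stack(
    [((packed >> (8 * i)) & 0xFF).to(torch.uint8).view(torch.float8_e4m3fn)
     for i in range(4)], dim=1).reshape(fp8.shape)
assert torch.equal(recovered.view(torch.uint8), fp8.view(torch.uint8))
```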