[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)
This commit is contained in:
parent
7cd2ebb025
commit
47f0954af0
@ -171,6 +171,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
|
||||
|
@ -93,6 +93,11 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
|
||||
|
||||
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
|
1308
csrc/quantization/fp8/fp8_marlin.cu
Normal file
1308
csrc/quantization/fp8/fp8_marlin.cu
Normal file
File diff suppressed because it is too large
Load Diff
@ -137,6 +137,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("gptq_marlin_repack", &gptq_marlin_repack);
|
||||
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
|
||||
|
||||
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
|
||||
ops.def("fp8_marlin_gemm", &fp8_marlin_gemm);
|
||||
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
|
||||
|
||||
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
ops.def(
|
||||
|
@ -4,7 +4,8 @@ FP8
|
||||
==================
|
||||
|
||||
vLLM supports FP8 (8-bit floating point) weight and activation quantization using hardware acceleration on GPUs such as Nvidia H100 and AMD MI300x.
|
||||
Currently, only Hopper and Ada Lovelace GPUs are supported.
|
||||
Currently, only Hopper and Ada Lovelace GPUs are officially supported for W8A8.
|
||||
Ampere GPUs are supported for W8A16 (weight-only FP8) utilizing Marlin kernels.
|
||||
Quantization of models with FP8 allows for a 2x reduction in model memory requirements and up to a 1.6x improvement in throughput with minimal impact on accuracy.
|
||||
|
||||
Please visit the HF collection of `quantized FP8 checkpoints of popular LLMs ready to use with vLLM <https://huggingface.co/collections/neuralmagic/fp8-llms-for-vllm-666742ed2b78b7ac8df13127>`_.
|
||||
|
@ -11,7 +11,7 @@ Implementation Volta Turing Ampere Ada Hopper AMD GPU Intel GPU x86
|
||||
AQLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
AWQ ❌ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
DeepSpeedFP ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
FP8 ❌ ❌ ❌ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
FP8 ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
Marlin ❌ ❌ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
GPTQ ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
SqueezeLLM ✅ ✅ ✅ ✅ ✅ ❌ ❌ ❌ ❌ ❌
|
||||
|
@ -8,7 +8,8 @@ import torch
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS,
|
||||
marlin_permute_scales)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
|
||||
@ -16,7 +17,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_perms import (
|
||||
marlin_perm)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MarlinWorkspace, compute_max_diff, is_marlin_supported, marlin_24_quantize,
|
||||
marlin_quantize, marlin_weights)
|
||||
marlin_quantize, marlin_weights, pack_fp8_to_int32)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, quantize_weights, sort_weights)
|
||||
|
||||
@ -38,9 +39,11 @@ MNK_FACTORS = [
|
||||
(67, 13, 11),
|
||||
]
|
||||
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
|
||||
def rand_data(shape):
|
||||
return torch.randn(shape, dtype=torch.half, device="cuda")
|
||||
|
||||
def rand_data(shape, dtype=torch.float16):
|
||||
return torch.randn(shape, dtype=dtype, device="cuda")
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_marlin_supported(),
|
||||
@ -217,3 +220,80 @@ def test_marlin_24_gemm(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
|
||||
print("max_diff = {}".format(max_diff))
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_marlin_supported(),
|
||||
reason="Marlin is not supported on this GPU type.")
|
||||
@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS)
|
||||
@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS)
|
||||
@pytest.mark.parametrize("num_bits", [8])
|
||||
@pytest.mark.parametrize("group_size", [-1])
|
||||
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_fp8_marlin_gemm(
|
||||
k_chunk,
|
||||
n_chunk,
|
||||
num_bits,
|
||||
group_size,
|
||||
mnk_factors,
|
||||
dtype,
|
||||
):
|
||||
m_factor, n_factor, k_factor = mnk_factors
|
||||
|
||||
size_m = m_factor
|
||||
size_k = k_chunk * k_factor
|
||||
size_n = n_chunk * n_factor
|
||||
|
||||
print(f"MNK = {size_m} {size_n} {size_k}")
|
||||
print(f"groupsize = {group_size}")
|
||||
|
||||
a_input = rand_data((size_m, size_k), dtype=dtype)
|
||||
b_weight = rand_data((size_k, size_n), dtype=dtype)
|
||||
|
||||
# WEIGHTS
|
||||
fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None)
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(fp8_weight)
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device="cuda"),
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda")
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(
|
||||
s=scales,
|
||||
size_k=size_k,
|
||||
size_n=size_n,
|
||||
group_size=-1,
|
||||
num_bits=8,
|
||||
)
|
||||
|
||||
workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_MAX_PARALLEL)
|
||||
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=a_input,
|
||||
b_q_weight=marlin_qweight,
|
||||
b_scales=marlin_scales,
|
||||
workspace=workspace.scratch,
|
||||
num_bits=num_bits,
|
||||
size_m=a_input.shape[0],
|
||||
size_n=b_weight.shape[1],
|
||||
size_k=a_input.shape[1],
|
||||
)
|
||||
output_ref = torch.matmul(a_input, b_weight)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
max_diff = compute_max_diff(output, output_ref)
|
||||
print("max_diff = {}".format(max_diff))
|
||||
|
||||
assert max_diff < 0.04
|
||||
|
@ -6,7 +6,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm._custom_ops import scaled_fp8_quant
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
||||
|
||||
MODELS = [
|
||||
@ -35,7 +35,16 @@ def test_load_fp16_model(vllm_runner) -> None:
|
||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
||||
fc1 = model.model.decoder.layers[0].fc1
|
||||
assert isinstance(fc1.quant_method, Fp8LinearMethod)
|
||||
assert fc1.weight.dtype == torch.float8_e4m3fn
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
if capability >= 89:
|
||||
# For GPUs with hardware support, we keep weights in fp8
|
||||
assert fc1.weight.dtype == torch.float8_e4m3fn
|
||||
else:
|
||||
# For GPUs without hardware support, we pack the fp8 weights
|
||||
# for weight-only quantization using Marlin kernels
|
||||
assert fc1.weight.dtype == torch.int32
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||
@ -63,7 +72,7 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
x = (torch.randn(size=(11, 11), device="cuda") * 13).to(dtype)
|
||||
|
||||
# Dynamic quantization
|
||||
ref_y, inv_scale = scaled_fp8_quant(x, None)
|
||||
ref_y, inv_scale = ops.scaled_fp8_quant(x, None)
|
||||
ref_y = per_tensor_dequantize(ref_y, inv_scale, dtype)
|
||||
|
||||
# Reference dynamic quantizaton
|
||||
@ -71,11 +80,11 @@ def test_scaled_fp8_quant(dtype) -> None:
|
||||
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
|
||||
|
||||
# Static quantization
|
||||
y, _ = scaled_fp8_quant(x, inv_scale)
|
||||
y, _ = ops.scaled_fp8_quant(x, inv_scale)
|
||||
assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
|
||||
|
||||
# Padding
|
||||
y, _ = scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
|
||||
y, _ = ops.scaled_fp8_quant(x, inv_scale, batch_dim_padding=17)
|
||||
assert y.shape[0] == 17
|
||||
assert torch.allclose(
|
||||
ref_y,
|
||||
|
@ -271,6 +271,15 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
size_k, is_k_full)
|
||||
|
||||
|
||||
# fp8 marlin
|
||||
def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
|
||||
b_scales: torch.Tensor, workspace: torch.Tensor,
|
||||
num_bits: int, size_m: int, size_n: int,
|
||||
size_k: int) -> torch.Tensor:
|
||||
return torch.ops._C.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace,
|
||||
num_bits, size_m, size_n, size_k)
|
||||
|
||||
|
||||
# fp8
|
||||
def scaled_fp8_quant(
|
||||
input: torch.Tensor,
|
||||
|
@ -11,6 +11,11 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
|
||||
marlin_permute_scales)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
pack_fp8_to_int32)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import print_warning_once
|
||||
@ -54,7 +59,7 @@ class Fp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 89
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
@ -106,6 +111,12 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
capability = current_platform.get_device_capability()
|
||||
capability = capability[0] * 10 + capability[1]
|
||||
self.use_marlin = capability < 89
|
||||
|
||||
def _create_scale_param(
|
||||
self,
|
||||
scale_name: str,
|
||||
@ -139,6 +150,10 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.process_after_load = True
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
|
||||
# WEIGHT
|
||||
weight_dtype = (torch.float8_e4m3fn
|
||||
if self.quant_config.is_checkpoint_fp8_serialized else
|
||||
@ -172,6 +187,65 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
**extra_weight_attrs)
|
||||
|
||||
# For GPUs without FP8 hardware support, we use Marlin for fast
|
||||
# fused dequantization
|
||||
if self.use_marlin:
|
||||
layer.marlin_state = GPTQMarlinState.REPACK
|
||||
|
||||
def prepare_layer_for_marlin(self, layer: Module) -> None:
|
||||
print_warning_once(
|
||||
"Your GPU does not have native support for FP8 computation but "
|
||||
"FP8 quantization is being used. Weight-only FP8 compression will "
|
||||
"be used leveraging the Marlin kernel. This may degrade "
|
||||
"performance for compute-heavy workloads.")
|
||||
|
||||
part_size_n = layer.output_size_per_partition
|
||||
part_size_k = layer.input_size_per_partition
|
||||
|
||||
assert layer.marlin_state == GPTQMarlinState.REPACK
|
||||
layer.marlin_state = GPTQMarlinState.READY
|
||||
|
||||
device = layer.weight.device
|
||||
|
||||
# WEIGHTS
|
||||
# Repack weights to gptq format (packed int32 elements)
|
||||
packed_gptq_qweight = pack_fp8_to_int32(layer.weight)
|
||||
|
||||
# Repack weights to marlin format
|
||||
marlin_qweight = ops.gptq_marlin_repack(
|
||||
b_q_weight=packed_gptq_qweight,
|
||||
perm=torch.empty(0, dtype=torch.int, device=device),
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight = Parameter(marlin_qweight, requires_grad=False)
|
||||
|
||||
# WEIGHT SCALES
|
||||
# Currently Marlin doesn't support per-tensor scales, so we
|
||||
# expand it to channelwise
|
||||
scales = layer.weight_scale.repeat(1, part_size_n).to(
|
||||
layer.orig_dtype).to(device)
|
||||
# Permute scales
|
||||
marlin_scales = marlin_permute_scales(
|
||||
s=scales,
|
||||
size_k=part_size_k,
|
||||
size_n=part_size_n,
|
||||
group_size=-1,
|
||||
num_bits=8,
|
||||
)
|
||||
layer.weight_scale = Parameter(marlin_scales, requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
max_workspace_size = (
|
||||
part_size_n // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
|
||||
workspace = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
requires_grad=False)
|
||||
|
||||
layer.workspace = workspace
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if (not hasattr(layer, "process_after_load")
|
||||
or not layer.process_after_load):
|
||||
@ -185,6 +259,8 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
layer.logical_widths = None
|
||||
layer.input_scale = None
|
||||
if self.use_marlin:
|
||||
self.prepare_layer_for_marlin(layer)
|
||||
return
|
||||
|
||||
# If checkpoint is fp8, requantize the separately quantized logical
|
||||
@ -233,44 +309,72 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
raise ValueError(
|
||||
f"Unknown scheme {self.quant_config.activation_scheme}")
|
||||
|
||||
if self.use_marlin:
|
||||
self.prepare_layer_for_marlin(layer)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
if self.use_marlin:
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the
|
||||
# Marlin kernel for fast weight-only FP8 quantization
|
||||
|
||||
if bias is None and self.cutlass_fp8_supported:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (layer.output_size_per_partition, )
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
output = ops.fp8_marlin_gemm(
|
||||
a=reshaped_x,
|
||||
b_q_weight=layer.weight,
|
||||
b_scales=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
num_bits=8,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x,
|
||||
layer.input_scale,
|
||||
batch_dim_padding=17)
|
||||
|
||||
# Fused GEMM_DQ -- note we padded the input above because
|
||||
# torch._scaled_mm is more performant for matrices with
|
||||
# batch dimension > 16. Note that this could change
|
||||
# in the future.
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale
|
||||
|
||||
if bias is None and self.cutlass_fp8_supported:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
|
||||
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
)
|
||||
|
||||
else:
|
||||
qinput, x_scale = ops.scaled_fp8_quant(x,
|
||||
layer.input_scale,
|
||||
batch_dim_padding=17)
|
||||
|
||||
# Fused GEMM_DQ -- note we padded the input above because
|
||||
# torch._scaled_mm is more performant for matrices with
|
||||
# batch dimension > 16. Note that this could change
|
||||
# in the future.
|
||||
output, _ = torch._scaled_mm(
|
||||
qinput,
|
||||
layer.weight,
|
||||
out_dtype=x.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return torch.narrow(output, 0, 0, x.shape[0])
|
||||
|
||||
|
@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
get_pack_factor, quantize_weights, sort_weights)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__cuda_arch = current_platform.get_device_capability()
|
||||
|
||||
MARLIN_TILE = 16
|
||||
|
||||
|
||||
def is_marlin_supported():
|
||||
return __cuda_arch[0] >= 8
|
||||
capability = current_platform.get_device_capability()
|
||||
return capability[0] >= 8
|
||||
|
||||
|
||||
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
|
||||
@ -223,3 +222,26 @@ class MarlinWorkspace:
|
||||
self.scratch = torch.zeros(max_workspace_size,
|
||||
dtype=torch.int,
|
||||
device="cuda")
|
||||
|
||||
|
||||
def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Repack FP8 weights to gptq format (packed int32 elements)
|
||||
"""
|
||||
assert fp8_tensor.dtype == torch.float8_e4m3fn
|
||||
assert fp8_tensor.shape[0] % 4 == 0
|
||||
|
||||
# Reshape to prepare for packing
|
||||
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
|
||||
|
||||
# Convert fp8 to uint8 (byte) representation
|
||||
byte_tensor = reshaped.view(torch.uint8)
|
||||
|
||||
# Pack 4 uint8 values into one int32
|
||||
packed = (byte_tensor[:, 0].to(torch.int32) |
|
||||
(byte_tensor[:, 1].to(torch.int32) << 8) |
|
||||
(byte_tensor[:, 2].to(torch.int32) << 16) |
|
||||
(byte_tensor[:, 3].to(torch.int32) << 24))
|
||||
|
||||
return packed.view(fp8_tensor.shape[0] // 4,
|
||||
*fp8_tensor.shape[1:]).contiguous()
|
||||
|
Loading…
x
Reference in New Issue
Block a user