diff --git a/CMakeLists.txt b/CMakeLists.txt
index a0fd346c..244ceb72 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -264,6 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/custom_all_reduce.cu"
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_compressor_entry.cu"
     "csrc/cutlass_extensions/common.cpp")
@@ -377,6 +378,23 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()
 
+  # FP4 Archs and flags
+  cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
+    set(SRCS
+      "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${FP4_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
+    message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
+  else()
+    message(STATUS "Not building NVFP4 as no compatible archs were found.")
+    # clear FP4_ARCHS
+    set(FP4_ARCHS)
+  endif()
 
   #
   # Machete kernels
diff --git a/cmake/utils.cmake b/cmake/utils.cmake
index 1c1c5398..c9cd099b 100644
--- a/cmake/utils.cmake
+++ b/cmake/utils.cmake
@@ -257,9 +257,9 @@ endmacro()
 # where `<=` is the version comparison operator.
 # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
 # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
-# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
-# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
-# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
+# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
+# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
 # The result is stored in `OUT_CUDA_ARCHS`.
 #
 # Example:
@@ -272,8 +272,8 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
   list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
   set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
 
-  # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
-  # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
+  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
   set(_CUDA_ARCHS)
   if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
     list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
@@ -283,6 +283,14 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
     endif()
   endif()
 
+  if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
+    if ("10.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
+      set(_CUDA_ARCHS "10.0a")
+    endif()
+  endif()
+
   list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
 
   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h
index c3522421..6f79d2b7 100644
--- a/csrc/cuda_utils.h
+++ b/csrc/cuda_utils.h
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <stdio.h>
+
 #if defined(__CUDACC__) || defined(_NVHPC_CUDA)
   #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
   #define DEVICE_INLINE __forceinline__ __device__
@@ -10,6 +12,16 @@
   #define HOST_INLINE inline
 #endif
 
+#define CUDA_CHECK(cmd)                                             \
+  do {                                                              \
+    cudaError_t e = cmd;                                            \
+    if (e != cudaSuccess) {                                         \
+      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
+             cudaGetErrorString(e));                                \
+      exit(EXIT_FAILURE);                                           \
+    }                                                               \
+  } while (0)
+
 int64_t get_device_attribute(int64_t attribute, int64_t device_id);
 
 int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu
index d6f9eb64..0627a426 100644
--- a/csrc/cuda_utils_kernels.cu
+++ b/csrc/cuda_utils_kernels.cu
@@ -1,16 +1,22 @@
+#include "cuda_utils.h"
 #ifdef USE_ROCM
   #include <hip/hip_runtime.h>
   #include <hip/hip_runtime_api.h>
 #endif
+
 int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
-  int device, value;
-  if (device_id < 0) {
-    cudaGetDevice(&device);
-  } else {
-    device = device_id;
-  }
-  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
-                         device);
+  // Return the cached value on subsequent calls
+  static int value = [=]() {
+    int device = static_cast<int>(device_id);
+    if (device < 0) {
+      CUDA_CHECK(cudaGetDevice(&device));
+    }
+    int value;
+    CUDA_CHECK(cudaDeviceGetAttribute(
+        &value, static_cast<cudaDeviceAttr>(attribute), device));
+    return static_cast<int>(value);
+  }();
+
   return value;
 }
diff --git a/csrc/ops.h b/csrc/ops.h
index e39d4ef3..70e864cc 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -195,6 +195,10 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
 
+void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
+                      torch::Tensor& output_scale,
+                      torch::Tensor const& input_scale);
+
 void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
 
diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu
new file mode 100644
index 00000000..b1426c43
--- /dev/null
+++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+void scaled_fp4_quant_sm100a(torch::Tensor const& output,
+                             torch::Tensor const& input,
+                             torch::Tensor const& output_sf,
+                             torch::Tensor const& input_sf);
+#endif
+
+void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
+                      torch::Tensor& output_sf, torch::Tensor const& input_sf) {
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
+}
diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
new file mode 100644
index 00000000..c3b8e9b3
--- /dev/null
+++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu
@@ -0,0 +1,379 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include <cuda_fp8.h>
+
+#include "cuda_utils.h"
+
+// Get type2 from type or vice versa (applied to half and bfloat16)
+template <typename T>
+struct TypeConverter {
+  using Type = half2;
+};  // keep for generality
+
+template <>
+struct TypeConverter<half2> {
+  using Type = half;
+};
+
+template <>
+struct TypeConverter<half> {
+  using Type = half2;
+};
+
+template <>
+struct TypeConverter<__nv_bfloat162> {
+  using Type = __nv_bfloat16;
+};
+
+template <>
+struct TypeConverter<__nv_bfloat16> {
+  using Type = __nv_bfloat162;
+};
+
+#define ELTS_PER_THREAD 8
+
+constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
+constexpr int CVT_FP4_SF_VEC_SIZE = 16;
+
+// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      ".reg .b8 byte1;\n"
+      ".reg .b8 byte2;\n"
+      ".reg .b8 byte3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+      "}"
+      : "=r"(val)
+      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
+        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
+  return val;
+#else
+  return 0;
+#endif
+}
+
+// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      ".reg .b8 byte1;\n"
+      ".reg .b8 byte2;\n"
+      ".reg .b8 byte3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+      "}"
+      : "=r"(val)
+      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
+        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
+  return val;
+#else
+  return 0;
+#endif
+}
+
+// Fast reciprocal.
+inline __device__ float reciprocal_approximate_ftz(float a) {
+  float b;
+  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
+  return b;
+}
+
+template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
+__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
+                                                       int numCols,
+                                                       SFType* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
+                CVT_FP4_NUM_THREADS_PER_SF == 2);
+
+  // One pair of threads write one SF to global memory.
+  // TODO: stage through smem for packed STG.32
+  // is it better than STG.8 from 4 threads ?
+  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
+    // SF vector index (16 elements share one SF in the K dimension).
+    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
+    int32_t mIdx = rowIdx;
+
+    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    int32_t mTileIdx = mIdx / (32 * 4);
+    // SF vector size 16.
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numKTiles = (numCols + factor - 1) / factor;
+    int64_t mTileStride = numKTiles * 32 * 4 * 4;
+
+    int32_t kTileIdx = (kIdx / 4);
+    int64_t kTileStride = 32 * 4 * 4;
+
+    // M tile layout [32, 4] is column-major.
+    int32_t outerMIdx = (mIdx % 32);
+    int64_t outerMStride = 4 * 4;
+
+    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
+    int64_t innerMStride = 4;
+
+    int32_t innerKIdx = (kIdx % 4);
+    int64_t innerKStride = 1;
+
+    // Compute the global offset.
+    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
+                       outerMIdx * outerMStride + innerMIdx * innerMStride +
+                       innerKIdx * innerKStride;
+
+    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
+  }
+#endif
+  return nullptr;
+}
+
+// Define a 16 bytes packed data type.
+template <class Type>
+struct PackedVec {
+  typename TypeConverter<Type>::Type elts[4];
+};
+
+template <>
+struct PackedVec<__nv_fp8_e4m3> {
+  __nv_fp8x2_e4m3 elts[8];
+};
+
+// Quantizes the provided PackedVec into the uint32_t output
+template <class Type, bool UE8M0_SF = false>
+__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
+                                         uint8_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  // Get absolute maximum values among the local 8 values.
+  auto localMax = __habs2(vec.elts[0]);
+
+  // Local maximum value.
+  #pragma unroll
+  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
+  }
+
+  // Get the absolute maximum among all 16 values (two threads).
+  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
+  // Get the final absolute maximum values.
+  float vecMax = float(__hmax(localMax.x, localMax.y));
+
+  // Get the SF (max value of the vector / max value of e2m1).
+  // maximum value of e2m1 = 6.0.
+  // TODO: use half as compute data type.
+  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
+  // 8 bits representation of the SF.
+  uint8_t fp8SFVal;
+  // Write the SF to global memory (STG.8).
+  if constexpr (UE8M0_SF) {
+    // Extract the 8 exponent bits from float32.
+    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
+    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
+    fp8SFVal = tmp & 0xff;
+    // Convert back to fp32.
+    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
+  } else {
+    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
+    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
+    // Convert back to fp32.
+    SFValue = float(tmp);
+  }
+  // Get the output scale.
+  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
+  //                       reciprocal(SFScaleVal))
+  float outputScale =
+      SFValue != 0 ? reciprocal_approximate_ftz(
+                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
+                   : 0.0f;
+
+  if (SFout) {
+    // Write the SF to global memory (STG.8).
+    *SFout = fp8SFVal;
+  }
+
+  // Convert the input to float.
+  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
+
+  #pragma unroll
+  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    if constexpr (std::is_same_v<Type, half>) {
+      fp2Vals[i] = __half22float2(vec.elts[i]);
+    } else {
+      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
+    }
+    fp2Vals[i].x *= outputScale;
+    fp2Vals[i].y *= outputScale;
+  }
+
+  // Convert to e2m1 values.
+  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+
+  // Write the e2m1 values to global memory.
+  return e2m1Vec;
+#else
+  return 0;
+#endif
+}
+
+// Use UE4M3 by default.
+template <class Type, bool UE8M0_SF = false>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(512, 4) cvt_fp16_to_fp4(
+#else
+cvt_fp16_to_fp4(
+#endif
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+    uint32_t* out, uint32_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
+      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
+
+  // Get the global scaling factor, which will be applied to the SF.
+  // Note SFScale is the same as next GEMM's alpha, which is
+  // (448.f / (Alpha_A / 6.f)).
+  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+
+  // Input tensor row/col loops.
+  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
+    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
+         colIdx += blockDim.x) {
+      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
+      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+      // Get the output tensor offset.
+      // Same as inOffset because 8 elements are packed into one uint32_t.
+      int64_t outOffset = inOffset;
+      auto& out_pos = out[outOffset];
+
+      auto sf_out =
+          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                             CVT_FP4_NUM_THREADS_PER_SF>(
+              rowIdx, colIdx, numCols, SFout);
+
+      out_pos =
+          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+    }
+  }
+#endif
+}
+
+template <typename T>
+void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
+                           int64_t* output, int32_t* SFOuput, bool useUE8M0,
+                           int multiProcessorCount, cudaStream_t stream) {
+  // Grid, Block size.
+  // Each thread converts 8 values.
+  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
+  // Get number of blocks per SM (assume we can fully utilize the SM).
+  int const numBlocksPerSM = 2048 / block.x;
+  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // Launch the cvt kernel.
+  if (useUE8M0) {
+    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
+        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
+        reinterpret_cast<uint32_t*>(SFOuput));
+  } else {
+    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
+        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
+        reinterpret_cast<uint32_t*>(SFOuput));
+  }
+}
+
+// Instantiate the function.
+template void invokeFP4Quantization(int m, int n, half const* input,
+                                    float const* SFScale, int64_t* output,
+                                    int32_t* SFOuput, bool useUE8M0,
+                                    int multiProcessorCount,
+                                    cudaStream_t stream);
+
+template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
+                                    float const* SFScale, int64_t* output,
+                                    int32_t* SFOuput, bool useUE8M0,
+                                    int multiProcessorCount,
+                                    cudaStream_t stream);
+
+void scaled_fp4_quant_sm100a(torch::Tensor const& output,
+                             torch::Tensor const& input,
+                             torch::Tensor const& output_sf,
+                             torch::Tensor const& input_sf) {
+  int32_t m = input.size(0);
+  int32_t n = input.size(1);
+
+  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
+
+  int multiProcessorCount =
+      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
+
+  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
+  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
+  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
+
+  // We don't support e8m0 scales at this moment.
+  bool useUE8M0 = false;
+
+  switch (input.scalar_type()) {
+    case torch::kHalf: {
+      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
+      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
+                            useUE8M0, multiProcessorCount, stream);
+      break;
+    }
+    case torch::kBFloat16: {
+      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
+      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
+                            useUE8M0, multiProcessorCount, stream);
+      break;
+    }
+    default: {
+      std::cerr << "Observing: " << input.scalar_type()
+                << " for the input datatype which is invalid";
+      throw std::runtime_error(
+          "Unsupported input data type for quantize_to_fp4.");
+    }
+  }
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index c03806f4..784ded26 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -423,6 +423,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
 
+  // Compute NVFP4 block quantized tensor.
+  ops.def(
+      "scaled_fp4_quant(Tensor! output, Tensor input,"
+      " Tensor! output_scale, Tensor input_scale) -> ()");
+  ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
diff --git a/tests/kernels/test_nvfp4_quant.py b/tests/kernels/test_nvfp4_quant.py
new file mode 100644
index 00000000..93735fc0
--- /dev/null
+++ b/tests/kernels/test_nvfp4_quant.py
@@ -0,0 +1,149 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
+
+if not current_platform.has_device_capability(100):
+    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
+                allow_module_level=True)
+
+DTYPES = [torch.float16, torch.bfloat16]
+SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
+PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
+              (90, 128), (150, 128), (150, 48), (90, 80)]
+SEEDS = [42]
+CUDA_DEVICES = ['cuda:0']
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+# E2M1 to float
+# 0111 -> 6
+# 0110 -> 4
+# 0101 -> 3
+# 0100 -> 2
+# 0011 -> 1.5
+# 0010 -> 1
+# 0001 -> 0.5
+# 0000 -> 0
+E2M1_TO_FLOAT32 = [
+    0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6.
+]
+BLOCK_SIZE = 16
+
+
+def cast_from_fp4(x, m, n):
+    # The fp4 values are packed in uint8 as [v_1st | v_2nd]
+    v_2nd = x & 0xF
+    v_1st = (x >> 4) & 0xF
+    c = torch.stack((v_2nd, v_1st), dim=-1)
+    out = torch.tensor([E2M1_TO_FLOAT32[x] for x in c.flatten()])
+    out = out.reshape(m, n).to(torch.float32)
+    return out
+
+
+def cast_to_fp4(x):
+    sign = torch.sign(x)
+    x = torch.abs(x)
+    x[(x >= 0.0) & (x <= 0.25)] = 0.0
+    x[(x > 0.25) & (x < 0.75)] = 0.5
+    x[(x >= 0.75) & (x <= 1.25)] = 1.0
+    x[(x > 1.25) & (x < 1.75)] = 1.5
+    x[(x >= 1.75) & (x <= 2.5)] = 2.0
+    x[(x > 2.5) & (x < 3.5)] = 3.0
+    x[(x >= 3.5) & (x <= 5.0)] = 4.0
+    x[x > 5.0] = 6.0
+    return x * sign
+
+
+def get_reciprocal(x):
+    if isinstance(x, torch.Tensor):
+        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
+    elif isinstance(x, (float, int)):
+        return 0.0 if x == 0 else 1.0 / x
+    else:
+        raise TypeError("Input must be a float, int, or a torch.Tensor.")
+
+
+def ref_nvfp4_quant(x, global_scale):
+    assert global_scale.dtype == torch.float32
+    assert x.ndim == 2
+    m, n = x.shape
+    x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
+    vec_max = torch.max(torch.abs(x), dim=-1,
+                        keepdim=True)[0].to(torch.float32)
+    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
+    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
+    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
+
+    scaled_x = x.to(torch.float32) * output_scale
+    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
+    return cast_to_fp4(clipped_x), scale.squeeze(-1)
+
+
+def recover_swizzled_scales(scale, m, n):
+    round_up = lambda x, y: (x + y - 1) // y * y
+    rounded_m = round_up(m, 128)
+    scale_n = n // BLOCK_SIZE
+    rounded_n = round_up(scale_n, 4)
+    # Recover the swizzled scaling factor to linear layout
+    tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
+    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+    result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
+    return result[:m, :scale_n]
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_quantize_to_fp4(
+    dtype: torch.dtype,
+    shape: tuple[int, int],
+    seed: int,
+    device: str,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    m, n = shape
+
+    x = torch.randn((m, n), dtype=dtype)
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
+
+    out, out_scale = ops.scaled_fp4_quant(x, global_scale)
+    scale_ans = recover_swizzled_scales(out_scale, m, n)
+    out_ans = cast_from_fp4(out, m, n)
+
+    torch.testing.assert_close(out_ans, out_ref)
+    torch.testing.assert_close(scale_ans, scale_ref)
+
+
+@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
+@torch.inference_mode()
+def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
+    dtype = torch.float16
+    current_platform.seed_everything(42)
+    torch.set_default_device('cuda:0')
+
+    m, n = pad_shape
+
+    x = torch.randn((m, n), dtype=dtype)
+
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
+
+    out, out_scale = ops.scaled_fp4_quant(x, global_scale)
+
+    scale_ans = recover_swizzled_scales(out_scale, m, n)
+    out_ans = cast_from_fp4(out, m, n)
+
+    torch.testing.assert_close(out_ans, out_ref)
+    torch.testing.assert_close(scale_ans, scale_ref)
diff --git a/tests/test_scalartype.py b/tests/test_scalartype.py
index 6e36f2c3..d0e57ea8 100644
--- a/tests/test_scalartype.py
+++ b/tests/test_scalartype.py
@@ -11,6 +11,7 @@ from vllm.scalar_type import scalar_types
     (0, 15, scalar_types.uint4),
     (-8, 7, scalar_types.uint4b8),
     (-128, 127, scalar_types.uint8b128),
+    (-6., 6., scalar_types.float4_e2m1fn),
     (-28., 28., scalar_types.float6_e3m2f),
     (torch.int8, scalar_types.int8),
     (torch.uint8, scalar_types.uint8),
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a6823501..67843c17 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -765,6 +765,63 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
     return torch.ops._C.permute_cols(a, perm)
 
 
+# fp4
+def scaled_fp4_quant(
+    input: torch.Tensor,
+    input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP4 and return quantized tensor and scale.
+
+    This function quantizes the last dimension of the given tensor `input`. For
+    every 16 consecutive elements, a single dynamically computed scaling factor
+    is shared. This scaling factor is quantized using the `input_global_scale`
+    and is stored in a swizzled layout (see
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
+
+    Args:
+        input: The input tensor to be quantized to FP4
+        input_global_scale: A scalar scaling factor for the entire tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 (every two
+            values packed into one uint8) and the float8_e4m3 scaling factors
+            in the swizzled layout.
+    """
+    assert input.ndim >= 1, (
+        f'input.ndim needs to be >= 1, but got {input.ndim}.')
+    other_dims = 1 if input.ndim == 1 else -1
+    input = input.reshape(other_dims, input.shape[-1])
+    m, n = input.shape
+    block_size = 16
+    device = input.device
+
+    assert n % block_size == 0, (
+        f'last dim has to be multiple of 16, but got {n}.')
+    assert input.dtype in (torch.float16, torch.bfloat16), (
+        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')
+
+    # Two fp4 values will be packed into a uint8.
+    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+
+    # We use the rounded values to store the swizzled values. Due to the
+    # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
+    # So, we first pad the scales to multiples of 128 and 4. Then, the scales
+    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
+    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
+    round_up = lambda x, y: (x + y - 1) // y * y
+    rounded_m = round_up(m, 128)
+    scale_n = n // block_size
+    rounded_n = round_up(scale_n, 4)
+    output_scale = torch.empty((rounded_m, rounded_n // 4),
+                               device=device,
+                               dtype=torch.int32)
+
+    torch.ops._C.scaled_fp4_quant(output, input, output_scale,
+                                  input_global_scale)
+    output_scale = output_scale.view(torch.float8_e4m3fn)
+    return output, output_scale
+
+
 # fp8
 def scaled_fp8_quant(
     input: torch.Tensor,
diff --git a/vllm/scalar_type.py b/vllm/scalar_type.py
index 9f6e8592..1d7675dd 100644
--- a/vllm/scalar_type.py
+++ b/vllm/scalar_type.py
@@ -321,6 +321,9 @@ class scalar_types:
     # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
 
+    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)
+
     # "gptq" types
     uint2b2 = ScalarType.uint(2, 2)
     uint3b4 = ScalarType.uint(3, 4)
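
Reviewer note (not part of the patch): below is a minimal, hypothetical usage sketch of the new custom op added by this diff, assuming a Blackwell (SM100a) GPU and a build configured with ENABLE_NVFP4. The constants and calls mirror tests/kernels/test_nvfp4_quant.py and vllm/_custom_ops.py above; the tensor shape and device string are arbitrary example choices.

import torch

from vllm import _custom_ops as ops
from vllm.scalar_type import scalar_types

FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()      # 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

# Example activation tensor; the last dimension must be a multiple of 16.
x = torch.randn(128, 64, dtype=torch.bfloat16, device="cuda")

# Per the kernel comments, the global scale is typically chosen as
# 448 * 6 / amax(x) so the per-16-element block scales fit into float8_e4m3.
global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / x.abs().max().float()

# `out` packs two e2m1 values per uint8, so it has shape (128, 32);
# `out_scale` holds one float8_e4m3 scale per 16 elements, stored in the
# swizzled 128x4 tile layout expected by the Blackwell tensor cores.
out, out_scale = ops.scaled_fp4_quant(x, global_scale)
print(out.shape, out.dtype)              # torch.Size([128, 32]) torch.uint8
print(out_scale.shape, out_scale.dtype)  # swizzled scales, torch.float8_e4m3fn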