diff --git a/CMakeLists.txt b/CMakeLists.txt
index cd1c2c90..4b569ec2 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -229,7 +229,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
   # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
   # Please keep this in sync with FetchContent_Declare line below.
-  set(CUTLASS_REVISION "v3.7.0" CACHE STRING "CUTLASS revision to use")
+  set(CUTLASS_REVISION "v3.8.0" CACHE STRING "CUTLASS revision to use")
 
   # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
   if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@@ -267,6 +267,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
     "csrc/quantization/fp4/nvfp4_quant_entry.cu"
+    "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/cutlass_extensions/common.cpp")
 
@@ -383,6 +384,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
+      "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
     )
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
diff --git a/csrc/ops.h b/csrc/ops.h
index 52ccf3b5..13fbbe41 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -152,6 +152,11 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
                               int64_t row);
 
 #ifndef USE_ROCM
+void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
+                           torch::Tensor const& B, torch::Tensor const& A_sf,
+                           torch::Tensor const& B_sf,
+                           torch::Tensor const& alpha);
+
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
 
 bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
new file mode 100644
index 00000000..a0852c57
--- /dev/null
+++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
+                                  torch::Tensor const& B,
+                                  torch::Tensor const& A_sf,
+                                  torch::Tensor const& B_sf,
+                                  torch::Tensor const& alpha);
+#endif
+
+void cutlass_scaled_fp4_mm(torch::Tensor& D, torch::Tensor const& A,
+                           torch::Tensor const& B, torch::Tensor const& A_sf,
+                           torch::Tensor const& B_sf,
+                           torch::Tensor const& alpha) {
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+  return cutlass_scaled_fp4_mm_sm100a(D, A, B, A_sf, B_sf, alpha);
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 mm kernel, vLLM should "
+                                     "be compiled using CUDA 12.8 and target "
+                                     "compute capability 100 or above.");
+}
diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
new file mode 100644
index 00000000..26fd9121
--- /dev/null
+++ b/csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
@@ -0,0 +1,280 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cutlass_extensions/common.hpp"
+
+#include "cutlass/cutlass.h"
+
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+
+#include "cutlass/util/packed_stride.hpp"
+
+using namespace cute;
+
+#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+// Kernel Perf config
+template <typename T>
+struct KernelTraits;
+
+template <>
+struct KernelTraits<float> {
+  using MmaTileShape = Shape<_128, _128, _256>;
+  using ClusterShape = Shape<_1, _1, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
+};
+
+template <>
+struct KernelTraits<cutlass::half_t> {
+  using MmaTileShape = Shape<_256, _256, _256>;
+  using ClusterShape = Shape<_4, _4, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
+};
+
+template <>
+struct KernelTraits<cutlass::bfloat16_t> {
+  using MmaTileShape = Shape<_256, _256, _256>;
+  using ClusterShape = Shape<_4, _4, _1>;
+  using PerSmTileShape_MNK = Shape<_128, _256, _256>;
+};
+
+template <typename T>
+struct Fp4GemmSm100 {
+  // A matrix configuration
+  using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using LayoutATag = cutlass::layout::RowMajor;
+  static constexpr int AlignmentA = 32;
+
+  // B matrix configuration
+  using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
+  using LayoutBTag = cutlass::layout::ColumnMajor;
+  static constexpr int AlignmentB = 32;
+
+  // C/D matrix configuration
+  using ElementD = T;
+  using ElementC = T;
+  using LayoutCTag = cutlass::layout::RowMajor;
+  using LayoutDTag = cutlass::layout::RowMajor;
+  static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
+  static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
+  // Kernel functional config
+  using ElementAccumulator = float;
+  using ArchTag = cutlass::arch::Sm100;
+  using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
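+  // OpClassBlockScaledTensorOp selects the SM100 block-scaled tensor core
+  // path: each group of 16 packed e2m1 values in A/B carries one ue4m3 scale
+  // factor that the mainloop applies while accumulating in fp32.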
+
+  // Kernel Perf config
+  using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
+  using ClusterShape = typename KernelTraits<T>::ClusterShape;
+  using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
+
+  using CollectiveEpilogue =
+      typename cutlass::epilogue::collective::CollectiveBuilder<
+          ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
+          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
+          ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
+          LayoutDTag, AlignmentD,
+          cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
+
+  using CollectiveMainloop =
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
+          LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  using StrideA = typename Gemm::GemmKernel::StrideA;
+  using LayoutA = decltype(cute::make_layout(make_shape(0, 0, 0), StrideA{}));
+  using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA;
+  using StrideB = typename Gemm::GemmKernel::StrideB;
+  using LayoutB = decltype(cute::make_layout(make_shape(0, 0, 0), StrideB{}));
+  using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB;
+  using StrideC = typename Gemm::GemmKernel::StrideC;
+  using LayoutC = decltype(cute::make_layout(make_shape(0, 0, 0), StrideC{}));
+  using StrideD = typename Gemm::GemmKernel::StrideD;
+  using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
+};
+
+template <typename T>
+typename T::Gemm::Arguments args_from_options(
+    at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
+    at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
+    int64_t M, int64_t N, int64_t K) {
+  using ElementA = typename T::Gemm::ElementA;
+  using ElementB = typename T::Gemm::ElementB;
+  using ElementSFA = cutlass::float_ue4m3_t;
+  using ElementSFB = cutlass::float_ue4m3_t;
+  using ElementD = typename T::Gemm::ElementD;
+  using ElementCompute = float;
+  using StrideA = typename T::StrideA;
+  using StrideB = typename T::StrideB;
+  using StrideD = typename T::StrideD;
+  using Sm100BlkScaledConfig =
+      typename T::Gemm::GemmKernel::CollectiveMainloop::Sm100BlkScaledConfig;
+
+  int m = static_cast<int>(M);
+  int n = static_cast<int>(N);
+  int k = static_cast<int>(K);
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1});
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1});
+  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1});
+
+  auto layout_SFA = Sm100BlkScaledConfig::tile_atom_to_shape_SFA(
+      cute::make_shape(m, n, k, 1));
+  auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
+      cute::make_shape(m, n, k, 1));
+
+  typename T::Gemm::Arguments arguments{
+      cutlass::gemm::GemmUniversalMode::kGemm,
+      {m, n, k, 1},
+      {// Mainloop arguments
+       static_cast<ElementA const*>(A.data_ptr()), stride_A,
+       static_cast<ElementB const*>(B.data_ptr()), stride_B,
+       static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA,
+       static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB},
+      {// Epilogue arguments
+       {},  // epilogue.thread
+       static_cast<ElementD const*>(D.data_ptr()),
+       stride_D,
+       static_cast<ElementD*>(D.data_ptr()),
+       stride_D}};
+  auto& fusion_args = arguments.epilogue.thread;
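+  // alpha is read on the device through alpha_ptr (a CUDA float tensor)
+  // rather than captured as a host scalar; the epilogue scales the fp32
+  // accumulator by it before storing D.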
+  fusion_args.alpha_ptr = static_cast<ElementCompute const*>(alpha.data_ptr());
+  return arguments;
+}
+
+template <typename T>
+void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
+             at::Tensor const& A_sf, at::Tensor const& B_sf,
+             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
+             cudaStream_t stream) {
+  typename Fp4GemmSm100<T>::Gemm gemm;
+
+  auto arguments =
+      args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
+
+  size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
+  auto const workspace_options =
+      torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
+  auto workspace = torch::empty(workspace_size, workspace_options);
+
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.data_ptr(), stream));
+
+  CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
+}
+#else
+template <typename T>
+void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
+             at::Tensor const& A_sf, at::Tensor const& B_sf,
+             at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
+             cudaStream_t stream) {
+  TORCH_CHECK(false, "Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
+                     "a CUTLASS 3.8 source directory to enable support.");
+}
+#endif  // defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
+
+#define CHECK_TYPE(x, st, m) \
+  TORCH_CHECK(x.scalar_type() == st, "Inconsistency of Tensor type:", m)
+#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x, m) \
+  TORCH_CHECK(x.is_contiguous(), m, "must be contiguous")
+#define CHECK_INPUT(x, st, m) \
+  CHECK_TH_CUDA(x, m);        \
+  CHECK_CONTIGUOUS(x, m);     \
+  CHECK_TYPE(x, st, m)
+
+constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
+constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
+
+void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
+                                  torch::Tensor const& B,
+                                  torch::Tensor const& A_sf,
+                                  torch::Tensor const& B_sf,
+                                  torch::Tensor const& alpha) {
+  CHECK_INPUT(A, FLOAT4_E2M1X2, "a");
+  CHECK_INPUT(B, FLOAT4_E2M1X2, "b");
+
+  CHECK_INPUT(A_sf, SF_DTYPE, "scale_a");
+  CHECK_INPUT(B_sf, SF_DTYPE, "scale_b");
+
+  CHECK_INPUT(alpha, at::ScalarType::Float, "alpha");
+
+  TORCH_CHECK(A.dim() == 2, "a must be a matrix");
+  TORCH_CHECK(B.dim() == 2, "b must be a matrix");
+  TORCH_CHECK(A.sizes()[1] == B.sizes()[1],
+              "a and b shapes cannot be multiplied (", A.sizes()[0], "x",
+              A.sizes()[1], " and ", B.sizes()[0], "x", B.sizes()[1], ")");
+
+  auto const m = A.sizes()[0];
+  auto const n = B.sizes()[0];
+  auto const k = A.sizes()[1] * 2;
+
+  constexpr int alignment = 32;
+  TORCH_CHECK(k % alignment == 0, "Expected k to be divisible by ", alignment,
+              ", but got a shape: (", A.sizes()[0], "x", A.sizes()[1],
+              "), k: ", k, ".");
+  TORCH_CHECK(n % alignment == 0, "Expected n to be divisible by ", alignment,
+              ", but got b shape: (", B.sizes()[0], "x", B.sizes()[1], ").");
+
+  auto round_up = [](int x, int y) { return (x + y - 1) / y * y; };
+  int rounded_m = round_up(m, 128);
+  int rounded_n = round_up(n, 128);
+  // Since k is divisible by 32 (alignment), k / 16 is guaranteed to be an
+  // integer.
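+  // Block scales are expected in CUTLASS's 128x4 swizzled atom layout: one
+  // ue4m3 scale per 16 fp4 values, with rows padded to a multiple of 128 and
+  // scale columns (k / 16) padded to a multiple of 4.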
+  int rounded_k = round_up(k / 16, 4);
+
+  TORCH_CHECK(A_sf.dim() == 2, "scale_a must be a matrix");
+  TORCH_CHECK(B_sf.dim() == 2, "scale_b must be a matrix");
+  TORCH_CHECK(A_sf.sizes()[1] == B_sf.sizes()[1],
+              "scale_a and scale_b shapes cannot be multiplied (",
+              A_sf.sizes()[0], "x", A_sf.sizes()[1], " and ", B_sf.sizes()[0],
+              "x", B_sf.sizes()[1], ")");
+  TORCH_CHECK(A_sf.sizes()[0] == rounded_m && A_sf.sizes()[1] == rounded_k,
+              "scale_a must be padded and swizzled to a shape (", rounded_m,
+              "x", rounded_k, "), but got a shape (", A_sf.sizes()[0], "x",
+              A_sf.sizes()[1], ")");
+  TORCH_CHECK(B_sf.sizes()[0] == rounded_n && B_sf.sizes()[1] == rounded_k,
+              "scale_b must be padded and swizzled to a shape (", rounded_n,
+              "x", rounded_k, "), but got a shape (", B_sf.sizes()[0], "x",
+              B_sf.sizes()[1], ")");
+
+  auto out_dtype = D.dtype();
+  at::cuda::CUDAGuard device_guard{(char)A.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
+
+  if (out_dtype == at::ScalarType::Half) {
+    runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (out_dtype == at::ScalarType::BFloat16) {
+    runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (out_dtype == at::ScalarType::Float) {
+    runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else {
+    TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
+  }
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index d2aecba4..72de2035 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -302,6 +302,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "SymInt size_k) -> Tensor");
   // conditionally compiled so impl registration is in source file
 
+  // CUTLASS nvfp4 block scaled GEMM
+  ops.def(
+      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
+      "                      Tensor block_scale_a, Tensor block_scale_b,"
+      "                      Tensor alpha) -> ()");
+  ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
+
   // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
   // quantization, as well as bias
   ops.def(
diff --git a/tests/kernels/test_nvfp4_scaled_mm.py b/tests/kernels/test_nvfp4_scaled_mm.py
new file mode 100644
index 00000000..b08026c5
--- /dev/null
+++ b/tests/kernels/test_nvfp4_scaled_mm.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
+
+if not current_platform.has_device_capability(100):
+    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
+                allow_module_level=True)
+
+DTYPES = [torch.float16, torch.bfloat16]
+# m, n, k
+SHAPES = [(128, 128, 64), (128, 128, 128), (256, 128, 64), (128, 256, 128)]
+PAD_SHAPES = [(150, 128, 64), (128, 128, 96)]
+SHAPES.extend(PAD_SHAPES)
+
+SEEDS = [42]
+CUDA_DEVICES = ['cuda:0']
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+kE2M1ToFloatArray = [
+    0.,
+    0.5,
+    1.,
+    1.5,
+    2.,
+    3.,
+    4.,
+    6.,
+]
+
+
+def e2m1_to_fp32(int4_value):
+    signBit = (int4_value & 0x8)
+    int4_absValue = int4_value & 0x7
+    float_result = kE2M1ToFloatArray[int4_absValue]
+    if (signBit):
+        float_result = -float_result
+    return float_result
+
+
+def break_fp4_bytes(a, dtype):
+    assert (a.dtype == torch.uint8)
+    m, n = a.shape
+    a = a.flatten()
+    # Get upper 4 bits
+    highHalfByte = (a & 0xF0) >> 4
+    # Get lower 4 bits
+    lowHalfByte = a & 0x0F
+    fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device)
+    fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device)
+    # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC]
+    out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2)
+    return out
+
+
+def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
+    sf_m, sf_k = a_sf_swizzled.shape
+    m_tiles = (m + 128 - 1) // 128
+    f = block_size * 4
+    k_tiles = (k + f - 1) // f
+    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
+    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
+    return out[0:m, 0:k]
+
+
+def dequantize_to_dtype(tensor_fp4,
+                        tensor_sf,
+                        global_scale,
+                        dtype,
+                        device,
+                        block_size=16):
+    """Dequantize the fp4 tensor back to high precision."""
+    # Two fp4 values are packed into one uint8.
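+    # Reference path: unpack the fp4 pairs to fp32, regroup k into blocks of
+    # block_size, then apply each block's fp8 scale divided by global_scale.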
+    assert tensor_fp4.dtype == torch.uint8
+    m, packed_k = tensor_fp4.shape
+    k = packed_k * 2
+    tensor_f32 = break_fp4_bytes(tensor_fp4, dtype)
+    tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size)
+    tensor_sf = tensor_sf.view(torch.float8_e4m3fn)
+    tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size)
+    tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale
+
+    # scale the tensor
+    out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k)
+    return out
+
+
+def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale,
+                    m, n, dtype, block_size, device):
+    _, m_k = a_fp4.shape
+    _, n_k = b_fp4.shape
+    assert (m_k == n_k)
+    a_in_dtype = dequantize_to_dtype(a_fp4,
+                                     a_sf,
+                                     a_global_scale,
+                                     dtype=dtype,
+                                     device=device,
+                                     block_size=block_size)
+    b_in_dtype = dequantize_to_dtype(b_fp4,
+                                     b_sf,
+                                     b_global_scale,
+                                     dtype=dtype,
+                                     device=device,
+                                     block_size=block_size)
+    return torch.matmul(a_in_dtype, b_in_dtype.t())
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_nvfp4_gemm(
+    dtype: torch.dtype,
+    shape: tuple[int, int, int],
+    seed: int,
+    device: str,
+) -> None:
+    current_platform.seed_everything(seed)
+    m, n, packed_k = shape
+    k = packed_k * 2
+    block_size = 16
+    a_dtype = torch.randn((m, k), dtype=dtype, device=device)
+    b_dtype = torch.randn((n, k), dtype=dtype, device=device)
+
+    a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+                      torch.amax(a_dtype.flatten(), dim=-1)).to(torch.float32)
+    b_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
+                      torch.amax(b_dtype.flatten(), dim=-1)).to(torch.float32)
+    alpha = 1. / (a_global_scale * b_global_scale)
+    a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a_dtype, a_global_scale)
+    b_fp4, b_scale_interleaved = ops.scaled_fp4_quant(b_dtype, b_global_scale)
+
+    expected_out = get_ref_results(a_fp4, b_fp4, a_scale_interleaved,
+                                   b_scale_interleaved, a_global_scale,
+                                   b_global_scale, m, n, dtype, block_size,
+                                   device)
+    out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved,
+                                    b_scale_interleaved, alpha, dtype)
+
+    torch.testing.assert_close(out,
+                               expected_out.to(dtype=dtype),
+                               atol=1e-1,
+                               rtol=1e-1)
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 2112af12..3306610a 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -433,6 +433,18 @@ if hasattr(torch.ops._C, "ggml_dequantize"):
 
 
 # cutlass
+def cutlass_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
+                          block_scale_a: torch.Tensor,
+                          block_scale_b: torch.Tensor, alpha: torch.Tensor,
+                          out_dtype: torch.dtype) -> torch.Tensor:
+    assert a.ndim == 2 and b.ndim == 2
+    m, n = a.shape[0], b.shape[0]
+    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+    torch.ops._C.cutlass_scaled_fp4_mm(out, a, b, block_scale_a, block_scale_b,
+                                       alpha)
+    return out
+
+
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
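
Usage sketch (not part of the patch): the flow below mirrors the new
tests/kernels/test_nvfp4_scaled_mm.py; shapes and variable names are
illustrative only.

    import torch
    from vllm import _custom_ops as ops
    from vllm.scalar_type import scalar_types

    FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
    FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

    a = torch.randn((128, 128), dtype=torch.half, device="cuda")
    b = torch.randn((256, 128), dtype=torch.half, device="cuda")

    # Per-tensor global scales; alpha undoes them in the GEMM epilogue.
    a_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX /
            torch.amax(a.flatten(), dim=-1)).to(torch.float32)
    b_gs = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX /
            torch.amax(b.flatten(), dim=-1)).to(torch.float32)
    alpha = 1. / (a_gs * b_gs)

    # Quantize to packed fp4 plus swizzled block scales, then run the GEMM.
    a_fp4, a_sf = ops.scaled_fp4_quant(a, a_gs)
    b_fp4, b_sf = ops.scaled_fp4_quant(b, b_gs)
    out = ops.cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_sf, b_sf, alpha,
                                    torch.half)  # shape (128, 256)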