diff --git a/csrc/ops.h b/csrc/ops.h index 46007889..52ccf3b5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -177,6 +177,10 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a, std::optional const& bias); std::vector cutlass_sparse_compress(torch::Tensor const& a); + +void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, + torch::Tensor& output_scale, + torch::Tensor const& input_scale); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, @@ -194,10 +198,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); -void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, - torch::Tensor& output_scale, - torch::Tensor const& input_scale); - void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 2fd45545..ef81db14 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -385,6 +385,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "bool silu_activation," "int pad_slot_id) -> ()"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); + + // Compute NVFP4 block quantized tensor. + ops.def( + "scaled_fp4_quant(Tensor! output, Tensor input," + " Tensor! output_scale, Tensor input_scale) -> ()"); + ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + #endif // Quantized GEMM for GPTQ. @@ -421,12 +428,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA, &dynamic_per_token_scaled_fp8_quant); - // Compute NVFP4 block quantized tensor. - ops.def( - "scaled_fp4_quant(Tensor! output, Tensor input," - " Tensor! output_scale, Tensor input_scale) -> ()"); - ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); - // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale," diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9f2ced8f..e3e3c644 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -774,6 +774,7 @@ def scaled_fp4_quant( two values are packed into a uint8 and float8_e4m3 scaling factors in the sizzled layout. """ + assert not current_platform.is_rocm() assert input.ndim >= 1, ( f'input.ndim needs to be >= 1, but got {input.ndim}.') other_dims = 1 if input.ndim == 1 else -1