diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 9de8d5af..233fc35d 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser -FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm( -) else torch.float8_e4m3fn +FP8_DTYPE = current_platform.fp8_dtype() class BenchmarkConfig(TypedDict): diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 03414b7e..dc6e0769 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -6,6 +6,11 @@ #include +// Need a special dispatch case macro since we will nest the FP8 dispatch. +// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'. +#define AT_DISPATCH_FP8_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__) + #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ @@ -14,17 +19,32 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) -// TODO(luka/varun): use FP8_TYPE macro after refactoring -#ifndef USE_ROCM +// ROCm devices might use either fn or fnuz, so set up dispatch table for both. +// A host-based check at runtime will create a preferred FP8 type for ROCm +// such that the correct kernel is dispatched. +#ifdef USE_ROCM + #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) + + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \ + AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) -#else - #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ - AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ - AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) #endif +// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'. +// See AT_DISPATCH_FP8_CASE above. +#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_quant_kernels.cu b/csrc/layernorm_quant_kernels.cu index c18e2a4e..d595b9e8 100644 --- a/csrc/layernorm_quant_kernels.cu +++ b/csrc/layernorm_quant_kernels.cu @@ -21,9 +21,9 @@ namespace vllm { // TODO(woosuk): Further optimize this kernel. 
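Note on the new dispatch macros: VLLM_DISPATCH_FP8_TYPES deliberately names the dispatched type fp8_t (via AT_DISPATCH_FP8_CASE) so it can be nested inside VLLM_DISPATCH_FLOATING_TYPES, which already owns scalar_t. The sketch below is illustrative only and simplified (no stream or device guard); toy_quant_kernel and toy_quant are hypothetical names, and the real launch sites are the rms_norm / scaled_fp8_quant changes later in this diff.

#include <torch/all.h>
#include "dispatch_utils.h"

namespace vllm {
// Hypothetical kernel templated on both the input dtype and the FP8 output dtype.
template <typename scalar_t, typename fp8_type>
__global__ void toy_quant_kernel(fp8_type* out, const scalar_t* in, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<fp8_type>(static_cast<float>(in[i]));
}
}  // namespace vllm

void toy_quant(torch::Tensor& out, const torch::Tensor& in) {
  int n = static_cast<int>(in.numel());
  // The outer dispatch binds scalar_t to the input dtype; the nested FP8
  // dispatch binds fp8_t to the *output* dtype, so one ROCm binary can serve
  // both e4m3fn and e4m3fnuz output tensors.
  VLLM_DISPATCH_FLOATING_TYPES(in.scalar_type(), "toy_quant_scalar_type", [&] {
    VLLM_DISPATCH_FP8_TYPES(out.scalar_type(), "toy_quant_fp8_type", [&] {
      vllm::toy_quant_kernel<scalar_t, fp8_t><<<(n + 255) / 256, 256>>>(
          out.data_ptr<fp8_t>(), in.data_ptr<scalar_t>(), n);
    });
  });
}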
-template +template __global__ void rms_norm_static_fp8_quant_kernel( - FP8_TYPE* __restrict__ out, // [..., hidden_size] + fp8_type* __restrict__ out, // [..., hidden_size] const scalar_t* __restrict__ input, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] const float* __restrict__ scale, // [1] @@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel( float x = (float)input[blockIdx.x * hidden_size + idx]; float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] = - scaled_fp8_conversion(out_norm, scale_inv); + scaled_fp8_conversion(out_norm, scale_inv); } } @@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel( Additional optimizations we can make in this case are packed and vectorized operations, which help with the memory latency bottleneck. */ -template +template __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> fused_add_rms_norm_static_fp8_quant_kernel( - FP8_TYPE* __restrict__ out, // [..., hidden_size] + fp8_type* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] @@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( #pragma unroll for (int i = 0; i < width; ++i) { out[id * width + i] = - scaled_fp8_conversion(float(temp.data[i]), scale_inv); + scaled_fp8_conversion(float(temp.data[i]), scale_inv); } } } @@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel( /* Generic fused_add_rms_norm_kernel The width field is not used here but necessary for other specializations. */ -template +template __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> fused_add_rms_norm_static_fp8_quant_kernel( - FP8_TYPE* __restrict__ out, // [..., hidden_size] + fp8_type* __restrict__ out, // [..., hidden_size] scalar_t* __restrict__ input, // [..., hidden_size] scalar_t* __restrict__ residual, // [..., hidden_size] const scalar_t* __restrict__ weight, // [hidden_size] @@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel( float x = (float)residual[blockIdx.x * hidden_size + idx]; float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx]; out[blockIdx.x * hidden_size + idx] = - scaled_fp8_conversion(out_norm, scale_inv); + scaled_fp8_conversion(out_norm, scale_inv); } } @@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size] dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_static_fp8_quant_kernel - <<>>( - out.data_ptr(), input.data_ptr(), - weight.data_ptr(), scale.data_ptr(), epsilon, - num_tokens, hidden_size); - }); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_kernel_scalar_type", [&] { + VLLM_DISPATCH_FP8_TYPES( + out.scalar_type(), "rms_norm_kernel_fp8_type", [&] { + vllm::rms_norm_static_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), scale.data_ptr(), + epsilon, num_tokens, hidden_size); + }); + }); } -#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ - VLLM_DISPATCH_FLOATING_TYPES( \ - input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ - vllm::fused_add_rms_norm_static_fp8_quant_kernel \ - <<>>( \ - out.data_ptr(), input.data_ptr(), \ - residual.data_ptr(), weight.data_ptr(), \ - 
scale.data_ptr(), epsilon, num_tokens, hidden_size); \ +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] { \ + VLLM_DISPATCH_FP8_TYPES( \ + out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \ + vllm::fused_add_rms_norm_static_fp8_quant_kernel \ + <<>>( \ + out.data_ptr(), input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), scale.data_ptr(), \ + epsilon, num_tokens, hidden_size); \ + }); \ }); - void fused_add_rms_norm_static_fp8_quant( torch::Tensor& out, // [..., hidden_size], torch::Tensor& input, // [..., hidden_size] diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index b812b28b..f01427cc 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -13,6 +13,28 @@ namespace vllm { namespace fp8 { #ifdef ENABLE_FP8 +// Use hardware cvt instruction for fp8 on rocm +template +__device__ __forceinline__ fp8_type cvt_c10(float const r) { + return {}; +} + +template <> +__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) { + return c10::Float8_e4m3fn( + __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation, + __hip_fp8_e4m3::__default_interpret), + c10::Float8_e4m3fn::from_bits()); +} + +template <> +__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) { + return c10::Float8_e4m3fnuz( + __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation, + __hip_fp8_e4m3_fnuz::__default_interpret), + c10::Float8_e4m3fnuz::from_bits()); +} + template __inline__ __device__ Tout vec_conversion(const Tin& x) { return x; diff --git a/csrc/quantization/fp8/common.cu b/csrc/quantization/fp8/common.cu index e4f6615e..8f9aa21a 100644 --- a/csrc/quantization/fp8/common.cu +++ b/csrc/quantization/fp8/common.cu @@ -11,8 +11,8 @@ namespace vllm { -template -__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out, +template +__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out, const scalar_t* __restrict__ input, const float* __restrict__ scale, int64_t num_elems) { @@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out, out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x); } -template +template __global__ void dynamic_per_token_scaled_fp8_quant_kernel( - FP8_TYPE* __restrict__ out, float* __restrict__ scale, + fp8_type* __restrict__ out, float* __restrict__ scale, scalar_t const* __restrict__ input, float const* __restrict__ scale_ub, const int hidden_size) { - float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + float const min_scaling_factor = + 1.0f / (fp8_e4m3_adjusted_max_v * 512.f); int const tid = threadIdx.x; int const token_idx = blockIdx.x; @@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( // Use int64 to avoid overflowing an int32 when calculating this offset int64_t offset = static_cast(token_idx) * hidden_size; scalar_t const* __restrict__ token_input = &input[offset]; - FP8_TYPE* __restrict__ token_output = &out[offset]; + fp8_type* __restrict__ token_output = &out[offset]; // For vectorization, token_input and token_output pointers need to be // aligned at 8-byte and 4-byte addresses respectively. 
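For readers of this kernel, the per-token scale it produces can be restated host-side as below. This is a sketch, not part of the patch: compute_token_scale is a hypothetical name, fp8_max stands in for fp8_e4m3_adjusted_max_v<fp8_type> (448.0f for e4m3fn, 224.0f for e4m3fnuz), and the optional cap mirrors the kernel's scale_ub argument.

#include <algorithm>

// Host-side restatement of the token-scale computation (illustrative only).
inline float compute_token_scale(float block_absmax, float fp8_max,
                                 const float* scale_ub /* may be null */) {
  // Floor on the scale so a tiny absmax cannot make the later x / scale blow up.
  float const min_scaling_factor = 1.0f / (fp8_max * 512.f);
  float token_absmax = block_absmax;
  if (scale_ub != nullptr) {
    token_absmax = std::min(token_absmax, *scale_ub);
  }
  return std::max(token_absmax / fp8_max, min_scaling_factor);
}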
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( token_scale = block_absmax_val_maybe; } // token scale computation - token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor); + token_scale = max(token_scale / fp8_e4m3_adjusted_max_v, + min_scaling_factor); scale[token_idx] = token_scale; } __syncthreads(); @@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel( token_output, token_input, token_scale, hidden_size, tid, blockDim.x); } else { for (int i = tid; i < hidden_size; i += blockDim.x) { - token_output[i] = scaled_fp8_conversion( + token_output[i] = scaled_fp8_conversion( static_cast(token_input[i]), token_scale); } } @@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d] const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "scaled_fp8_quant_kernel", [&] { - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { + VLLM_DISPATCH_FP8_TYPES( + out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { + vllm::scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); + }); }); } @@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d] const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "scaled_fp8_quant_kernel", [&] { - vllm::segmented_max_reduction<<>>( - scale.data_ptr(), input.data_ptr(), num_elems); - vllm::scaled_fp8_quant_kernel<<>>( - out.data_ptr(), input.data_ptr(), - scale.data_ptr(), num_elems); + input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] { + VLLM_DISPATCH_FP8_TYPES( + out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] { + vllm::segmented_max_reduction + <<>>(scale.data_ptr(), + input.data_ptr(), + num_elems); + vllm::scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), input.data_ptr(), + scale.data_ptr(), num_elems); + }); }); } @@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant( const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] { - vllm::dynamic_per_token_scaled_fp8_quant_kernel - <<>>( - out.data_ptr(), scales.data_ptr(), - input.data_ptr(), - scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, - hidden_size); + input.scalar_type(), + "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] { + VLLM_DISPATCH_FP8_TYPES( + out.scalar_type(), + "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] { + vllm::dynamic_per_token_scaled_fp8_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), + scale_ub.has_value() ? 
scale_ub->data_ptr() + : nullptr, + hidden_size); + }); }); } diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index fac99b29..d331c63a 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -7,18 +7,52 @@ #ifndef USE_ROCM #include -using FP8_TYPE = c10::Float8_e4m3fn; -C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = - std::numeric_limits::max(); + #define MAYBE_HOST_DEVICE C10_HOST_DEVICE #else + #include + #include #include #include "amd/quant_utils.cuh" -using FP8_TYPE = c10::Float8_e4m3fnuz; -// Using the default max value from pytorch (240.0) will cause accuracy -// issue when running dynamic quantization. Here use 224.0f for rocm. -constexpr auto FP8_E4M3_MAX = 224.0f; + // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr + #define MAYBE_HOST_DEVICE #endif -constexpr static auto kFp8Type = c10::CppTypeToScalarType::value; + +// Determines the preferred FP8 type for the current platform. +// Note that for CUDA this just returns true, +// but on ROCm it will check device props. +static bool is_fp8_ocp() { +#ifndef USE_ROCM + return true; +#else + auto dprops = at::cuda::getCurrentDeviceProperties(); + std::string device_arch = dprops->gcnArchName; + size_t substring = device_arch.find("gfx94"); + return substring == std::string::npos; +#endif +} + +template +struct fp8_e4m3_adjusted_max; + +template <> +struct fp8_e4m3_adjusted_max { + static constexpr c10::Float8_e4m3fn val() { + return std::numeric_limits::max(); + } +}; + +// Using the default max value from pytorch (240.0 0x7F) will cause accuracy +// issues when running dynamic quantization. Here use 224.0 0x7E for rocm. +template <> +struct fp8_e4m3_adjusted_max { + static constexpr c10::Float8_e4m3fnuz val() { + return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits()); + } +}; + +template +MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v = + fp8_e4m3_adjusted_max::val(); namespace vllm { @@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { return old; } -template -__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, +template +__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val, float const scale) { float x = 0.0f; if constexpr (is_scale_inverted) { @@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, x = val / scale; } - float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + float r = fmax(-fp8_e4m3_adjusted_max_v, + fmin(x, fp8_e4m3_adjusted_max_v)); #ifndef USE_ROCM - return static_cast(r); + return static_cast(r); #else // Use hardware cvt instruction for fp8 on rocm - return c10::Float8_e4m3fnuz( - __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation, - fp8::fp8_type::__default_interpret), - c10::Float8_e4m3fnuz::from_bits()); + return fp8::cvt_c10(r); #endif } @@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, // So to get the right answer, *scale needs to be initialized to // a value <= 0.0 and we need to wait for all thread blocks to // finish before consuming *scale. 
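The clamp bound used by scaled_fp8_conversion now depends on the FP8 flavor through fp8_e4m3_adjusted_max_v. A minimal host-side sketch, assuming a translation unit that includes this common.cuh (clamp_to_fp8_range is an illustrative name, not part of the patch):

#include <algorithm>

template <typename fp8_type>
float clamp_to_fp8_range(float x) {
  float const fp8_max = static_cast<float>(fp8_e4m3_adjusted_max_v<fp8_type>);
  return std::max(-fp8_max, std::min(x, fp8_max));
}
// clamp_to_fp8_range<c10::Float8_e4m3fn>(1000.f)   -> 448.0f
// clamp_to_fp8_range<c10::Float8_e4m3fnuz>(1000.f) -> 224.0f (0x7E, not 240.0)

This is why the ROCm fnuz path quantizes against 224.0 rather than PyTorch's default 240.0 maximum.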
-template +template __global__ void segmented_max_reduction(float* __restrict__ scale, const scalar_t* __restrict__ input, int64_t num_elems) { @@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale, // Finally, since cache[0] contains the maximum for this thread block, // atomically write the max to the target location if (threadIdx.x == 0) { - atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX); + atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v); } } @@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input, return absmax_val; } -template -__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, +template +__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out, scalar_t const* __restrict__ input, float const scale, int64_t const num_elems, int const tid, int const step) { - using float8x4_t = q8x4_t; + using float8x4_t = q8x4_t; // Vectorized input/output to better utilize memory bandwidth. auto const* vectorized_in = reinterpret_cast const*>(input); auto* vectorized_out = reinterpret_cast(out); @@ -141,22 +173,22 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, vec4_t in_vec = vectorized_in[i]; float8x4_t out_vec; - out_vec.x = scaled_fp8_conversion( + out_vec.x = scaled_fp8_conversion( static_cast(in_vec.x), scale); - out_vec.y = scaled_fp8_conversion( + out_vec.y = scaled_fp8_conversion( static_cast(in_vec.y), scale); - out_vec.z = scaled_fp8_conversion( + out_vec.z = scaled_fp8_conversion( static_cast(in_vec.z), scale); - out_vec.w = scaled_fp8_conversion( + out_vec.w = scaled_fp8_conversion( static_cast(in_vec.w), scale); vectorized_out[i] = out_vec; } // Handle the remaining elements if num_elems is not divisible by 4 for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) { - out[i] = scaled_fp8_conversion( + out[i] = scaled_fp8_conversion( static_cast(input[i]), scale); } } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 3c4f183b..1be89c50 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant( torch::Tensor& scales, // [num_tokens] double const var_epsilon, // Variance epsilon used in norm calculation std::optional scale_ub, std::optional residual) { + static c10::ScalarType kFp8Type = is_fp8_ocp() + ? 
c10::ScalarType::Float8_e4m3fn + : c10::ScalarType::Float8_e4m3fnuz; TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index f8a98722..9ac7b188 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { #endif } -static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) { - float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); - return static_cast(r); +template +static __device__ __forceinline__ fp8_type float_to_fp8(float const x) { + float const r = fmax(-fp8_e4m3_adjusted_max_v, + fmin(x, fp8_e4m3_adjusted_max_v)); + return static_cast(r); } template @@ -54,15 +56,16 @@ struct ScaledQuant< }; template -struct ScaledQuant< - quant_type_t, is_scale_inverted, - typename std::enable_if_t>> { +struct ScaledQuant || + std::is_same_v>> { static __device__ __forceinline__ quant_type_t quant_fn(float const x, float const scale) { if constexpr (is_scale_inverted) { - return float_to_fp8(x * scale); + return float_to_fp8(x * scale); } else { - return float_to_fp8(x / scale); + return float_to_fp8(x / scale); } } }; diff --git a/csrc/quantization/vectorization.cuh b/csrc/quantization/vectorization.cuh index 44c99913..866da10b 100644 --- a/csrc/quantization/vectorization.cuh +++ b/csrc/quantization/vectorization.cuh @@ -4,7 +4,6 @@ */ // Include both AMD and NVIDIA fp8 types to avoid circular import -// TODO(luka/varun) use FP8_TYPE instead after refactoring #include #include diff --git a/tests/kernels/quant_utils.py b/tests/kernels/quant_utils.py index a21d642b..498da600 100644 --- a/tests/kernels/quant_utils.py +++ b/tests/kernels/quant_utils.py @@ -9,8 +9,7 @@ from vllm.platforms import current_platform # Using the default value (240.0) from pytorch will cause accuracy # issue on dynamic quantization models. Here use 224.0 for rocm. 
ROCM_FP8_MAX = 224.0 -FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \ - else torch.float8_e4m3fn +FP8_DTYPE = current_platform.fp8_dtype() def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: diff --git a/tests/kernels/test_triton_scaled_mm.py b/tests/kernels/test_triton_scaled_mm.py index bbff3e0a..45f10b0e 100644 --- a/tests/kernels/test_triton_scaled_mm.py +++ b/tests/kernels/test_triton_scaled_mm.py @@ -32,11 +32,8 @@ def scaled_mm_torch(a: torch.Tensor, def get_8bit_types(): types = [torch.int8] - supports_fp8 = current_platform.has_device_capability(89) - if current_platform.is_rocm() and supports_fp8: - types.append(torch.float8_e4m3fnuz) - elif current_platform.is_cuda() and supports_fp8: - types.append(torch.float8_e4m3fn) + if current_platform.supports_fp8(): + types.append(current_platform.fp8_dtype()) return types diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index 3a7f0a19..b9a1d759 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, assert attn._v_scale == 1.0 if current_platform.is_cuda(): - if current_platform.has_device_capability( - 89) and not force_marlin: + if current_platform.supports_fp8() and not force_marlin: # For GPUs with hardware support, we keep weights in fp8 assert fc1.weight.dtype == torch.float8_e4m3fn else: @@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool, # for weight-only quantization using Marlin kernels assert fc1.weight.dtype == torch.int32 elif current_platform.is_rocm(): - # Only MI300 and above support quantization='fp8' - if current_platform.has_device_capability( - 94) and not force_marlin: + if current_platform.supports_fp8() and not force_marlin: # For GPUs with hardware support, we keep weights in fp8 - assert fc1.weight.dtype == torch.float8_e4m3fnuz + assert fc1.weight.dtype == current_platform.fp8_dtype() else: # unsupported ROCm platform pytest.skip( "Skip `test_load_fp16_model`. " diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 53065dd0..14cfe751 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -478,16 +478,16 @@ def cutlass_scaled_mm(a: torch.Tensor, out_dtype: torch.dtype, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """ - `cutlass_scaled_mm` implements a fused version of + `cutlass_scaled_mm` implements a fused version of `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)` - where scale_a * a and scale_b * b are implemented using numpy-style - broadcasting. - - In order to support blockwise scaling like found in DeepSeek V3 we also - support extended "group" broadcast rules. We extend the numpy-style - broadcasting rules with the following rule: - "if the extent of a dimension in the source shape is between 1 and - corresponding extent in the target shape we repeat each element along + where scale_a * a and scale_b * b are implemented using numpy-style + broadcasting. + + In order to support blockwise scaling like found in DeepSeek V3 we also + support extended "group" broadcast rules. 
We extend the numpy-style + broadcasting rules with the following rule: + "if the extent of a dimension in the source shape is between 1 and + corresponding extent in the target shape we repeat each element along that dimension src_shape[dim] // target_shape[dim] times consecutively" example if we have: a = [[1, 2], and target_shape = (2, 4) @@ -564,7 +564,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \ with Cutlass sparse kernels. Args: - a (torch.Tensor): + a (torch.Tensor): The input tensor to be compressed. Must have one of the following data types: - `torch.int8` - `torch.float8_e4m3fn` @@ -572,7 +572,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \ - `torch.float16` Returns: - tuple[torch.Tensor, torch.Tensor]: + tuple[torch.Tensor, torch.Tensor]: A tuple containing: - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`. - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation. @@ -875,9 +875,8 @@ def scaled_fp8_quant( # This code assumes batch_dim and num_tokens are flattened assert (input.ndim == 2) shape: Union[tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) output = torch.empty(shape, device=input.device, dtype=out_dtype) @@ -908,7 +907,7 @@ def allspark_repack_weight( has_zp: bool = False ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format + Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format for Ampere W8A16 Fused Gemm kernel Args: @@ -917,10 +916,10 @@ def allspark_repack_weight( zero_point: fp16/bf16 weight zero_point tensor, 1 x n format. Must be provided for asymmetric quantization. has_zp: if use symmetric quantization, has_zp = False. - if use asymmetric quantization, has_zp = True. - + if use asymmetric quantization, has_zp = True. + Returns: - tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : rearranged weight, scale, and optionally zero_point. 
""" K = qweight.shape[0] diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 4f4b70cd..e912b1e9 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsW8A8Fp8) from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8) + Fp8LinearGenericOp, is_fp8) from vllm.model_executor.layers.quantization.utils.quant_utils import ( scaled_quantize) from vllm.model_executor.layers.rotary_embedding import ( @@ -1238,7 +1238,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): W_Q_UK, W_Q_UK_scales = scaled_quantize( W_Q_UK, self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) + quant_dtype=current_platform.fp8_dtype()) # For FP8 save the transpose so we can use # `apply_w8a8_block_fp8_linear` directly self.W_Q_UK = W_Q_UK.T.contiguous() @@ -1255,7 +1255,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): W_UV_O, W_UV_O_scales = scaled_quantize( W_UV_O, self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) + quant_dtype=current_platform.fp8_dtype()) # For FP8 save the transpose so we can use # `apply_w8a8_block_fp8_linear` directly self.W_UV_O = W_UV_O.T.contiguous() diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c9aa0ec2..ff381a4c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -158,8 +158,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm, normalize the weights and scales to e4m3fnuz - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index aca25c9b..27a74d67 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -42,7 +42,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): logical_widths=layer.logical_widths, ) - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): input_scale = getattr(layer, 'input_scale', None) weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( @@ -60,7 +60,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): elif self.strategy == QuantizationStrategy.CHANNEL: weight = layer.weight - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): input_scale = getattr(layer, 'input_scale', None) weight, weight_scale, input_scale = \ diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index 110e4ef2..1cc431c5 100644 --- 
a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -127,7 +127,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): weight = layer.weight - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 3f8e0a2f..2d5d8e6a 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -270,7 +270,7 @@ class Fp8LinearMethod(LinearMethodBase): # TODO(rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): weight, weight_scale_inv, _ = \ normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, @@ -327,8 +327,7 @@ class Fp8LinearMethod(LinearMethodBase): weight = layer.weight weight_scale = layer.weight_scale - # If rocm, use float8_e4m3fnuz. - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, @@ -533,7 +532,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # TODO (rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): w13_weight, w13_weight_scale_inv, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( layer.w13_weight, layer.w13_weight_scale_inv, @@ -559,9 +558,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): # If checkpoint is fp16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: - # If rocm, use float8_e4m3fnuz as dtype - fp8_dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + fp8_dtype = current_platform.fp8_dtype() w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -608,8 +605,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm, normalize the weights and scales to e4m3fnuz - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 32dce5aa..bc26a455 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -142,8 +142,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm, normalize the weights and scales to e4m3fnuz - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index 7676fbdd..3e4251e4 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ 
b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -39,7 +39,7 @@ class QuarkW8A8Fp8(QuarkScheme): logical_widths=layer.logical_widths, ) - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=max_w_scale, @@ -55,7 +55,7 @@ class QuarkW8A8Fp8(QuarkScheme): elif self.qscheme == "per_channel": weight = layer.weight - if current_platform.is_rocm(): + if current_platform.is_fp8_fnuz(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 62569185..1e19302c 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op logger = init_logger(__name__) -current_platform_fp8_dtype = (torch.float8_e4m3fnuz - if current_platform.is_rocm() else - torch.float8_e4m3fn) - def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: if isinstance(x, torch.Tensor): @@ -165,9 +161,7 @@ def input_to_float8( ) -> Tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values " "with tensor-wise quantization.""" - if dtype is None: - dtype = (torch.float8_e4m3fnuz - if current_platform.is_rocm() else torch.float8_e4m3fn) + dtype = current_platform.fp8_dtype() if dtype is None else dtype finfo = torch.finfo(dtype) min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) @@ -311,9 +305,7 @@ def per_token_group_quant_fp8( Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ - if dtype is None: - dtype = (torch.float8_e4m3fnuz - if current_platform.is_rocm() else torch.float8_e4m3fn) + dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index a986ec0a..38975843 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -293,6 +293,10 @@ class CudaPlatformBase(Platform): def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + @classmethod + def supports_fp8(cls) -> bool: + return cls.has_device_capability(89) + # NVML utils # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`, diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index e7e55e11..7415b5d5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -330,6 +330,36 @@ class Platform: """ return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase" # noqa + @classmethod + def supports_fp8(cls) -> bool: + """ + Returns whether the current platform supports FP8 types. + """ + return False + + @classmethod + def is_fp8_fnuz(cls) -> bool: + """ + Returns whether the preferred FP8 type is FNUZ on the current platform. + + There are two representations of FP8, OCP FP8 and FNUZ FP8. + The OCP specification can be found at https://tinyurl.com/b7jvwpft. + The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5. + + AMD's MI300 and MI325 have native hardware support for FNUZ. All other + hardware has converged on the OCP FP8 standard. 
+ """ + return False + + @classmethod + def fp8_dtype(cls) -> torch.dtype: + """ + Returns the preferred FP8 type on the current platform. + + See the documentation for is_fp8_fnuz for details. + """ + return torch.float8_e4m3fn + @classmethod def use_all_gather(cls) -> bool: """ diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index de4f6070..75f287b5 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -231,3 +231,20 @@ class RocmPlatform(Platform): @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator" # noqa + + @classmethod + def supports_fp8(cls) -> bool: + gcn_arch = torch.cuda.get_device_properties(0).gcnArchName + return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12']) + + @classmethod + def is_fp8_fnuz(cls) -> bool: + # only device 0 is checked, this assumes MI300 platforms are homogeneous + return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName + + @classmethod + def fp8_dtype(cls) -> torch.dtype: + if cls.is_fp8_fnuz(): + return torch.float8_e4m3fnuz + else: + return torch.float8_e4m3fn diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 526b792a..14a7bd35 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsW8A8Fp8) from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8) + Fp8LinearGenericOp, is_fp8) from vllm.model_executor.layers.quantization.utils.quant_utils import ( scaled_quantize) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding @@ -826,7 +826,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): W_Q_UK, W_Q_UK_scales = scaled_quantize( W_Q_UK, self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) + quant_dtype=current_platform.fp8_dtype()) # For FP8 save the transpose so we can use # `apply_w8a8_block_fp8_linear` directly self.W_Q_UK = W_Q_UK.T.contiguous() @@ -843,7 +843,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): W_UV_O, W_UV_O_scales = scaled_quantize( W_UV_O, self.reqaunt_weight_group_shape, - quant_dtype=current_platform_fp8_dtype) + quant_dtype=current_platform.fp8_dtype()) # For FP8 save the transpose so we can use # `apply_w8a8_block_fp8_linear` directly self.W_UV_O = W_UV_O.T.contiguous()