dynamic dispatch of fp8 kernels (#14245)

Signed-off-by: Jeff Daily <jeff.daily@amd.com>
Jeff Daily 2025-03-11 07:54:56 -07:00 committed by GitHub
parent 08a1a1121d
commit a1c8f3796c
25 changed files with 292 additions and 159 deletions

View File

@@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser

-FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
-) else torch.float8_e4m3fn
+FP8_DTYPE = current_platform.fp8_dtype()

 class BenchmarkConfig(TypedDict):

View File

@@ -6,6 +6,11 @@

 #include <torch/all.h>

+// Need a special dispatch case macro since we will nest the FP8 dispatch.
+// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
+#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
+  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
+
 #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)         \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)  \
@@ -14,17 +19,32 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

-// TODO(luka/varun): use FP8_TYPE macro after refactoring
-#ifndef USE_ROCM
+// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
+// A host-based check at runtime will create a preferred FP8 type for ROCm
+// such that the correct kernel is dispatched.
+#ifdef USE_ROCM
+#define VLLM_DISPATCH_CASE_FP8_TYPES(...)                            \
+  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
+  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
+
+#define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
+  AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#else
+#define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
+  AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+
 #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                    \
   AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
-#else
-#define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
-  AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
-  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
 #endif

+// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
+// See AT_DISPATCH_FP8_CASE above.
+#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
+
 #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
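
For orientation, here is a minimal sketch of how the nested dispatch is meant to be used from a host-side launcher. The kernel and function names (example_quant_kernel, example_quant) and the launch configuration are illustrative only, not part of this commit; the real launchers updated below (layernorm and fp8 quant) follow the same shape.

// Illustrative sketch only -- not part of this commit. Assumes the updated
// dispatch_utils.h above is on the include path.
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "dispatch_utils.h"

namespace vllm {
// Hypothetical kernel: templated on the activation type and the fp8 type.
template <typename scalar_t, typename fp8_type>
__global__ void example_quant_kernel(fp8_type* __restrict__ out,
                                     const scalar_t* __restrict__ input,
                                     const float* __restrict__ scale,
                                     int64_t num_elems) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_elems) {
    // The real kernels in this commit use scaled_fp8_conversion instead.
    out[i] = static_cast<fp8_type>(static_cast<float>(input[i]) / scale[0]);
  }
}
}  // namespace vllm

void example_quant(torch::Tensor& out,          // fp8 output (fn or fnuz)
                   torch::Tensor const& input,  // fp16/bf16/fp32 input
                   torch::Tensor const& scale) {
  int64_t num_elems = input.numel();
  dim3 grid((num_elems + 1023) / 1024);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Outer switch: the input's floating type becomes 'scalar_t'.
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "example_quant_scalar_type", [&] {
        // Inner switch: the output's fp8 type becomes 'fp8_t'
        // (e4m3fn everywhere; e4m3fnuz is an additional case on ROCm builds).
        VLLM_DISPATCH_FP8_TYPES(
            out.scalar_type(), "example_quant_fp8_type", [&] {
              vllm::example_quant_kernel<scalar_t, fp8_t>
                  <<<grid, block, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
                      scale.data_ptr<float>(), num_elems);
            });
      });
}

The key difference from the old FP8_TYPE-based code is that the fp8 element type is now chosen from out.scalar_type() at runtime, so a single ROCm build can serve both OCP (e4m3fn) and FNUZ (e4m3fnuz) tensors.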

View File

@@ -21,9 +21,9 @@
 namespace vllm {

 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,   // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
     float x = (float)input[blockIdx.x * hidden_size + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true>(out_norm, scale_inv);
+        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
   }
 }
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
    memory latency bottleneck. */
-template <typename scalar_t, int width>
+template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 #pragma unroll
     for (int i = 0; i < width; ++i) {
       out[id * width + i] =
-          scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
+          scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
     }
   }
 }
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 /* Generic fused_add_rms_norm_kernel
    The width field is not used here but necessary for other specializations.
  */
-template <typename scalar_t, int width>
+template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
     float x = (float)residual[blockIdx.x * hidden_size + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true>(out_norm, scale_inv);
+        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
   }
 }
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
-        <<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
-            num_tokens, hidden_size);
-  });
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
+              vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
+                      epsilon, num_tokens, hidden_size);
+            });
+      });
 }

 #define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
   VLLM_DISPATCH_FLOATING_TYPES(                                              \
-      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                \
-        vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width>    \
-            <<<grid, block, 0, stream>>>(                                    \
-                out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),        \
-                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),  \
-                scale.data_ptr<float>(), epsilon, num_tokens, hidden_size);  \
+      input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] {    \
+        VLLM_DISPATCH_FP8_TYPES(                                             \
+            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {   \
+              vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t,     \
+                                                               width, fp8_t> \
+                  <<<grid, block, 0, stream>>>(                              \
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),     \
+                      residual.data_ptr<scalar_t>(),                         \
+                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),  \
+                      epsilon, num_tokens, hidden_size);                     \
+            });                                                              \
       });

 void fused_add_rms_norm_static_fp8_quant(
     torch::Tensor& out,     // [..., hidden_size],
     torch::Tensor& input,   // [..., hidden_size]

View File

@@ -13,6 +13,28 @@ namespace vllm {
 namespace fp8 {
 #ifdef ENABLE_FP8

+// Use hardware cvt instruction for fp8 on rocm
+template <typename fp8_type>
+__device__ __forceinline__ fp8_type cvt_c10(float const r) {
+  return {};
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
+  return c10::Float8_e4m3fn(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
+                             __hip_fp8_e4m3::__default_interpret),
+      c10::Float8_e4m3fn::from_bits());
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
+  return c10::Float8_e4m3fnuz(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
+                             __hip_fp8_e4m3_fnuz::__default_interpret),
+      c10::Float8_e4m3fnuz::from_bits());
+}
+
 template <typename Tout, typename Tin>
 __inline__ __device__ Tout vec_conversion(const Tin& x) {
   return x;
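
These specializations let device code select the HIP hardware conversion path from a template parameter instead of the old compile-time FP8_TYPE alias. As a minimal ROCm-only sketch (the helper name saturate_to_fp8 is illustrative and assumes a ROCm build with ENABLE_FP8 and this header included), a caller could clamp and convert like this:

// Illustrative sketch only: clamp to a caller-supplied max, then convert via
// the specialization selected by fp8_t.
template <typename fp8_t>
__device__ __forceinline__ fp8_t saturate_to_fp8(float v, float fp8_max) {
  float const r = fmaxf(-fp8_max, fminf(v, fp8_max));
  return vllm::fp8::cvt_c10<fp8_t>(r);
}

which is essentially what the reworked scaled_fp8_conversion in the fp8 common header later in this commit does, with the bound taken from fp8_e4m3_adjusted_max_v<fp8_type>.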

View File

@@ -11,8 +11,8 @@
 namespace vllm {

-template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
       out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }

-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
+    fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
+  float const min_scaling_factor =
+      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);

   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
@@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   // Use int64 to avoid overflowing an int32 when calculating this offset
   int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
   scalar_t const* __restrict__ token_input = &input[offset];
-  FP8_TYPE* __restrict__ token_output = &out[offset];
+  fp8_type* __restrict__ token_output = &out[offset];

   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
+                      min_scaling_factor);
     scale[token_idx] = token_scale;
   }
   __syncthreads();
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
         token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false>(
+      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
           static_cast<float>(token_input[i]), token_scale);
     }
   }
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
-            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::segmented_max_reduction<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
+                                               input.data_ptr<scalar_t>(),
+                                               num_elems);
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
-        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
-                input.data_ptr<scalar_t>(),
-                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-                hidden_size);
+      input.scalar_type(),
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(),
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                      input.data_ptr<scalar_t>(),
+                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
+                                           : nullptr,
+                      hidden_size);
+            });
       });
 }

View File

@@ -7,18 +7,52 @@
 #ifndef USE_ROCM
 #include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
-    std::numeric_limits<FP8_TYPE>::max();
+#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
 #else
+#include <ATen/hip/HIPContext.h>
+#include <c10/util/Float8_e4m3fn.h>
 #include <c10/util/Float8_e4m3fnuz.h>
 #include "amd/quant_utils.cuh"
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-// Using the default max value from pytorch (240.0) will cause accuracy
-// issue when running dynamic quantization. Here use 224.0f for rocm.
-constexpr auto FP8_E4M3_MAX = 224.0f;
+// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
+#define MAYBE_HOST_DEVICE
 #endif
-constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
+
+// Determines the preferred FP8 type for the current platform.
+// Note that for CUDA this just returns true,
+// but on ROCm it will check device props.
+static bool is_fp8_ocp() {
+#ifndef USE_ROCM
+  return true;
+#else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  std::string device_arch = dprops->gcnArchName;
+  size_t substring = device_arch.find("gfx94");
+  return substring == std::string::npos;
+#endif
+}
+
+template <typename T>
+struct fp8_e4m3_adjusted_max;
+
+template <>
+struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
+  static constexpr c10::Float8_e4m3fn val() {
+    return std::numeric_limits<c10::Float8_e4m3fn>::max();
+  }
+};
+
+// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
+// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
+template <>
+struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
+  static constexpr c10::Float8_e4m3fnuz val() {
+    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
+  }
+};
+
+template <typename T>
+MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
+    fp8_e4m3_adjusted_max<T>::val();

 namespace vllm {
@@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }

-template <bool is_scale_inverted>
-__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+template <bool is_scale_inverted, typename fp8_type>
+__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                           float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
@@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
     x = val / scale;
   }

-  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+  float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
+                 fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
 #ifndef USE_ROCM
-  return static_cast<c10::Float8_e4m3fn>(r);
+  return static_cast<fp8_type>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
-  return c10::Float8_e4m3fnuz(
-      __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
-                             fp8::fp8_type::__default_interpret),
-      c10::Float8_e4m3fnuz::from_bits());
+  return fp8::cvt_c10<fp8_type>(r);
 #endif
 }
@@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
 // So to get the right answer, *scale needs to be initialized to
 // a value <= 0.0 and we need to wait for all thread blocks to
 // finish before consuming *scale.
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void segmented_max_reduction(float* __restrict__ scale,
                                         const scalar_t* __restrict__ input,
                                         int64_t num_elems) {
@@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
+    atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
   }
 }
@@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   return absmax_val;
 }

-template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
+__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
-  using float8x4_t = q8x4_t<FP8_TYPE>;
+  using float8x4_t = q8x4_t<fp8_type>;
   // Vectorized input/output to better utilize memory bandwidth.
   auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
   auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
@@ -141,22 +173,22 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;

-    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
         static_cast<float>(in_vec.x), scale);
-    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.y), scale);
-    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.z), scale);
-    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
-    out[i] = scaled_fp8_conversion<is_scale_inverted>(
+    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(input[i]), scale);
  }
 }

 }  // namespace vllm
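
Host code resolves the preferred FP8 ScalarType once at runtime via is_fp8_ocp(), while device code replaces the old FP8_E4M3_MAX constant with fp8_e4m3_adjusted_max_v<fp8_type>. Below is a short host-side sketch, assuming this header is included; the function name check_fp8_output_dtype is illustrative, and the pattern mirrors the kFp8Type selection added to rms_norm_dynamic_per_token_quant in the next file.

// Illustrative sketch only -- not part of this commit.
#include <torch/all.h>

void check_fp8_output_dtype(torch::Tensor const& out) {
  // gfx94x devices (MI300/MI325) prefer the FNUZ encoding; everything else,
  // including CUDA builds, uses the OCP e4m3fn encoding.
  static const c10::ScalarType kFp8Type =
      is_fp8_ocp() ? c10::ScalarType::Float8_e4m3fn
                   : c10::ScalarType::Float8_e4m3fnuz;
  TORCH_CHECK(out.scalar_type() == kFp8Type, "expected fp8 output of type ",
              kFp8Type, " but got ", out.scalar_type());
}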

View File

@@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
     torch::Tensor& scales,      // [num_tokens]
     double const var_epsilon,   // Variance epsilon used in norm calculation
     std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
+  static c10::ScalarType kFp8Type = is_fp8_ocp()
+                                        ? c10::ScalarType::Float8_e4m3fn
+                                        : c10::ScalarType::Float8_e4m3fnuz;
   TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
   TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

View File

@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
 #endif
 }

-static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
-  float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
-  return static_cast<FP8_TYPE>(r);
+template <typename fp8_type>
+static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
+  float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
+                       fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  return static_cast<fp8_type>(r);
 }

 template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
@@ -54,15 +56,16 @@ struct ScaledQuant<
 };

 template <typename quant_type_t, bool is_scale_inverted>
-struct ScaledQuant<
-    quant_type_t, is_scale_inverted,
-    typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
+struct ScaledQuant<quant_type_t, is_scale_inverted,
+                   typename std::enable_if_t<
+                       std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
+                       std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
   static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                           float const scale) {
     if constexpr (is_scale_inverted) {
-      return float_to_fp8(x * scale);
+      return float_to_fp8<quant_type_t>(x * scale);
     } else {
-      return float_to_fp8(x / scale);
+      return float_to_fp8<quant_type_t>(x / scale);
     }
   }
 };

View File

@@ -4,7 +4,6 @@
  */

 // Include both AMD and NVIDIA fp8 types to avoid circular import
-// TODO(luka/varun) use FP8_TYPE instead after refactoring
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e4m3fn.h>

View File

@@ -9,8 +9,7 @@ from vllm.platforms import current_platform
 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for rocm.
 ROCM_FP8_MAX = 224.0
-FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
-    else torch.float8_e4m3fn
+FP8_DTYPE = current_platform.fp8_dtype()

 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:

View File

@@ -32,11 +32,8 @@ def scaled_mm_torch(a: torch.Tensor,

 def get_8bit_types():
     types = [torch.int8]
-    supports_fp8 = current_platform.has_device_capability(89)
-    if current_platform.is_rocm() and supports_fp8:
-        types.append(torch.float8_e4m3fnuz)
-    elif current_platform.is_cuda() and supports_fp8:
-        types.append(torch.float8_e4m3fn)
+    if current_platform.supports_fp8():
+        types.append(current_platform.fp8_dtype())
     return types

View File

@@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
     assert attn._v_scale == 1.0

     if current_platform.is_cuda():
-        if current_platform.has_device_capability(
-                89) and not force_marlin:
+        if current_platform.supports_fp8() and not force_marlin:
             # For GPUs with hardware support, we keep weights in fp8
             assert fc1.weight.dtype == torch.float8_e4m3fn
         else:
@@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
             # for weight-only quantization using Marlin kernels
             assert fc1.weight.dtype == torch.int32
     elif current_platform.is_rocm():
-        # Only MI300 and above support quantization='fp8'
-        if current_platform.has_device_capability(
-                94) and not force_marlin:
+        if current_platform.supports_fp8() and not force_marlin:
             # For GPUs with hardware support, we keep weights in fp8
-            assert fc1.weight.dtype == torch.float8_e4m3fnuz
+            assert fc1.weight.dtype == current_platform.fp8_dtype()
         else:  # unsupported ROCm platform
             pytest.skip(
                 "Skip `test_load_fp16_model`. "

View File

@@ -478,16 +478,16 @@ def cutlass_scaled_mm(a: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     """
     `cutlass_scaled_mm` implements a fused version of
     `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
     where scale_a * a and scale_b * b are implemented using numpy-style
     broadcasting.

     In order to support blockwise scaling like found in DeepSeek V3 we also
     support extended "group" broadcast rules. We extend the numpy-style
     broadcasting rules with the following rule:
         "if the extent of a dimension in the source shape is between 1 and
         corresponding extent in the target shape we repeat each element along
         that dimension src_shape[dim] // target_shape[dim] times consecutively"

     example if we have:
         a = [[1, 2], and target_shape = (2, 4)
@@ -564,7 +564,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
     with Cutlass sparse kernels.

     Args:
         a (torch.Tensor):
             The input tensor to be compressed. Must have one of the following data types:
             - `torch.int8`
             - `torch.float8_e4m3fn`
@@ -572,7 +572,7 @@
             - `torch.float16`

     Returns:
         tuple[torch.Tensor, torch.Tensor]:
             A tuple containing:
             - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
             - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
@@ -875,9 +875,8 @@ def scaled_fp8_quant(
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
     shape: Union[tuple[int, int], torch.Size] = input.shape
-    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-    out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-        if current_platform.is_rocm() else torch.float8_e4m3fn
+    # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
+    out_dtype: torch.dtype = current_platform.fp8_dtype()
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=out_dtype)
@@ -908,7 +907,7 @@ def allspark_repack_weight(
     has_zp: bool = False
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
     for Ampere W8A16 Fused Gemm kernel

     Args:
@@ -917,10 +916,10 @@ def allspark_repack_weight(
         zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
             Must be provided for asymmetric quantization.
         has_zp: if use symmetric quantization, has_zp = False.
             if use asymmetric quantization, has_zp = True.

     Returns:
         tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
             rearranged weight, scale, and optionally zero_point.
     """
     K = qweight.shape[0]

View File

@@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
@@ -1238,7 +1238,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             W_Q_UK, W_Q_UK_scales = scaled_quantize(
                 W_Q_UK,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_Q_UK = W_Q_UK.T.contiguous()
@@ -1255,7 +1255,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             W_UV_O, W_UV_O_scales = scaled_quantize(
                 W_UV_O,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_UV_O = W_UV_O.T.contiguous()

View File

@@ -158,8 +158,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)

-        # If rocm, normalize the weights and scales to e4m3fnuz
-        if current_platform.is_rocm():
+        if current_platform.is_fp8_fnuz():
             # Normalize the weights and scales
             w13_weight, w13_weight_scale, w13_input_scale = \
                 normalize_e4m3fn_to_e4m3fnuz(

View File

@@ -42,7 +42,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 logical_widths=layer.logical_widths,
             )

-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 input_scale = getattr(layer, 'input_scale', None)

                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
@@ -60,7 +60,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         elif self.strategy == QuantizationStrategy.CHANNEL:
             weight = layer.weight

-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 input_scale = getattr(layer, 'input_scale', None)

                 weight, weight_scale, input_scale = \

View File

@@ -127,7 +127,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):

         weight = layer.weight

-        if current_platform.is_rocm():
+        if current_platform.is_fp8_fnuz():
             weight, weight_scale, input_scale = \
                 normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,

View File

@@ -270,7 +270,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, weight_scale_inv, _ = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=layer.weight,
@@ -327,8 +327,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale

-            # If rocm, use float8_e4m3fnuz.
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
@@ -533,7 +532,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 w13_weight, w13_weight_scale_inv, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w13_weight, layer.w13_weight_scale_inv,
@@ -559,9 +558,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz \
-                if current_platform.is_rocm() else torch.float8_e4m3fn
+            fp8_dtype = current_platform.fp8_dtype()
             w13_weight = torch.empty_like(layer.w13_weight.data,
                                           dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -608,8 +605,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w13_input_scale.max(), requires_grad=False)
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
-            # If rocm, normalize the weights and scales to e4m3fnuz
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(

View File

@@ -142,8 +142,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)

-        # If rocm, normalize the weights and scales to e4m3fnuz
-        if current_platform.is_rocm():
+        if current_platform.is_fp8_fnuz():
             # Normalize the weights and scales
             w13_weight, w13_weight_scale, w13_input_scale = \
                 normalize_e4m3fn_to_e4m3fnuz(

View File

@@ -39,7 +39,7 @@ class QuarkW8A8Fp8(QuarkScheme):
                 logical_widths=layer.logical_widths,
             )

-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=max_w_scale,
@@ -55,7 +55,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         elif self.qscheme == "per_channel":
             weight = layer.weight

-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,

View File

@@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op

 logger = init_logger(__name__)

-current_platform_fp8_dtype = (torch.float8_e4m3fnuz
-                              if current_platform.is_rocm() else
-                              torch.float8_e4m3fn)
-

 def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     if isinstance(x, torch.Tensor):
@@ -165,9 +161,7 @@ def input_to_float8(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
-    if dtype is None:
-        dtype = (torch.float8_e4m3fnuz
-                 if current_platform.is_rocm() else torch.float8_e4m3fn)
+    dtype = current_platform.fp8_dtype() if dtype is None else dtype
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -311,9 +305,7 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
-    if dtype is None:
-        dtype = (torch.float8_e4m3fnuz
-                 if current_platform.is_rocm() else torch.float8_e4m3fn)
+    dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")

View File

@@ -293,6 +293,10 @@ class CudaPlatformBase(Platform):
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa

+    @classmethod
+    def supports_fp8(cls) -> bool:
+        return cls.has_device_capability(89)
+

 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

View File

@@ -330,6 +330,36 @@ class Platform:
         """
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa

+    @classmethod
+    def supports_fp8(cls) -> bool:
+        """
+        Returns whether the current platform supports FP8 types.
+        """
+        return False
+
+    @classmethod
+    def is_fp8_fnuz(cls) -> bool:
+        """
+        Returns whether the preferred FP8 type is FNUZ on the current platform.
+
+        There are two representations of FP8, OCP FP8 and FNUZ FP8.
+        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
+        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.
+
+        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
+        hardware has converged on the OCP FP8 standard.
+        """
+        return False
+
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        """
+        Returns the preferred FP8 type on the current platform.
+
+        See the documentation for is_fp8_fnuz for details.
+        """
+        return torch.float8_e4m3fn
+
     @classmethod
     def use_all_gather(cls) -> bool:
         """

View File

@@ -231,3 +231,20 @@ class RocmPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+
+    @classmethod
+    def supports_fp8(cls) -> bool:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])
+
+    @classmethod
+    def is_fp8_fnuz(cls) -> bool:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
+        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
+
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.is_fp8_fnuz():
+            return torch.float8_e4m3fnuz
+        else:
+            return torch.float8_e4m3fn

View File

@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
@@ -826,7 +826,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             W_Q_UK, W_Q_UK_scales = scaled_quantize(
                 W_Q_UK,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_Q_UK = W_Q_UK.T.contiguous()
@@ -843,7 +843,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             W_UV_O, W_UV_O_scales = scaled_quantize(
                 W_UV_O,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_UV_O = W_UV_O.T.contiguous()