[torch.compile] Fuse RMSNorm with quant (#9138)

Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
Luka Govedič 2024-11-08 16:20:08 -05:00 committed by GitHub
parent e1b5a82179
commit 4f93dfe952
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 1335 additions and 368 deletions

View File

@ -191,6 +191,7 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"

View File

@ -1,21 +1,13 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include "dispatch_utils.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif
namespace vllm {
@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the

View File

@ -0,0 +1,234 @@
/*
* This file contains the CUDA kernels for the fused quantized layernorm.
* The kernels correspond to the kernels in layernorm_kernels.cu, except they
* also produce quantized output directly.
* Currently, only static fp8 quantization is supported.
*/
#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
const scalar_t* __restrict__ input, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
#pragma unroll
for (int i = 0; i < width; ++i) {
out[id * width + i] =
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
}
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out, // [..., hidden_size]
scalar_t* __restrict__ input, // [..., hidden_size]
scalar_t* __restrict__ residual, // [..., hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float* __restrict__ scale, // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();
// invert scale to avoid division
float const scale_inv = 1.0f / *scale;
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
}
}
} // namespace vllm
void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
num_tokens, hidden_size);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
});
void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out, // [..., hidden_size],
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& residual, // [..., hidden_size]
torch::Tensor& weight, // [hidden_size]
torch::Tensor& scale, // [1]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}

View File

@ -56,6 +56,16 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);
void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);
void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);

View File

@ -1,185 +1,16 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <cmath>
#include "cuda_compat.h"
#include "common.cuh"
#include "dispatch_utils.h"
#include <c10/cuda/CUDAGuard.h>
#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif
#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
}
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;
__syncthreads();
// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
}
}
template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};
typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}
return absmax_val;
}
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
int64_t const num_vec_elems = num_elems >> 2;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input,

View File

@ -0,0 +1,172 @@
#pragma once
#include <cmath>
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));
return old;
}
template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}
// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;
__syncthreads();
// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
}
}
template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};
typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}
return absmax_val;
}
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
int64_t const num_vec_elems = num_elems >> 2;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}
} // namespace vllm

View File

@ -101,7 +101,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm", torch::kCUDA, &rms_norm);
@ -111,6 +111,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
"()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant);
// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(
@ -322,18 +339,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);
// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);
// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
@ -341,13 +360,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);
// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);

165
csrc/type_convert.cuh Normal file
View File

@ -0,0 +1,165 @@
#pragma once
#include <torch/all.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif
namespace vllm {
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;
__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;
__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];
__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}
__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}
__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
} // namespace vllm

33
tests/compile/backend.py Normal file
View File

@ -0,0 +1,33 @@
from copy import deepcopy
from typing import Callable
import torch
class TestBackend:
"""
This class provides a simple Inductor backend that can be used for testing.
It takes a list of custom passes and runs them after Inductor's passes.
It also saves the graph before and after the custom passes for inspection.
"""
def __init__(self, *args: Callable[[torch.fx.Graph], None]):
self.custom_passes = args
from torch._inductor import config
self.current_config = config.shallow_copy_dict()
self.current_config['post_grad_custom_post_pass'] = self.post_pass
def __call__(self, graph: torch.fx.GraphModule, example_inputs):
from torch._inductor.compile_fx import compile_fx
return compile_fx(graph,
example_inputs,
config_patches=self.current_config)
def post_pass(self, graph: torch.fx.Graph):
self.graph_pre_pass = deepcopy(graph)
for pass_ in self.custom_passes:
pass_(graph)
self.graph_post_pass = deepcopy(graph)
# assign by reference, will reflect the final state of the graph
self.final_graph = graph

View File

@ -0,0 +1,92 @@
import pytest
import torch
from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
super().__init__(*args, **kwargs)
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
]
def forward(self, x):
resid = torch.relu(x)
y = self.norm[0](x)
x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
# Init does pattern registration, which can only happen once
config = CompilationConfig(enable_fusion=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
if eps != 1e-5:
pytest.skip("Only test eps=1e-5 for now")
# Reshape pass is needed for the fusion pass to work
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)
result = model(x)
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Check that it gives the same answer
torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes
rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
# In post-nodes, fused kernels should be present and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None

View File

@ -1,13 +1,14 @@
import pytest
import torch
from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform
DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
8199] # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]
@ -59,3 +60,75 @@ def test_rms_norm(
else:
opcheck(torch.ops._C.rms_norm,
(out, x, layer.weight.data, layer.variance_epsilon))
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_rms_norm_quant(
num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
quant_scale: float,
seed: int,
device: str,
) -> None:
current_platform.seed_everything(seed)
torch.set_default_device(device)
weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x *= scale
if add_residual:
residual = torch.randn_like(x) * scale
residual_fused = residual.clone()
else:
residual = residual_fused = None
out_norm = torch.empty_like(x)
out_quant = torch.empty_like(x, dtype=FP8_DTYPE)
out_quant_fused = torch.empty_like(out_quant)
quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32)
if add_residual:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)
# Unfused kernel is in-place so it goes second
# Also use a separate clone of x to avoid modifying the input
x_unfused = x.clone()
torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
quant_scale_t)
torch.cuda.synchronize()
torch.testing.assert_close(residual_fused,
residual,
atol=1e-2,
rtol=1e-2)
opcheck(
torch.ops._C.fused_add_rms_norm_static_fp8_quant,
(out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
else:
torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight,
quant_scale_t, 1e-6)
torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm,
quant_scale_t)
opcheck(torch.ops._C.rms_norm_static_fp8_quant,
(out_quant_fused, x, weight, quant_scale_t, 1e-6))
torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
out_quant.to(dtype=torch.float32),
atol=1e-3,
rtol=1e-3)

View File

@ -2,7 +2,8 @@ import copy
import dataclasses
import operator
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
Union)
from unittest.mock import patch
import torch
@ -10,11 +11,13 @@ import torch.fx as fx
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
from vllm.utils import combine_fx_passes, weak_ref_tensors
from .config import CompilationConfig
from .counter import compilation_counter
from .fusion import FusionPass
from .levels import CompilationLevel
from .reshapes import RedundantReshapesPass
logger = init_logger(__name__)
@ -99,28 +102,74 @@ def fix_functionalization(graph: fx.Graph):
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif (node.args[0] ==
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default):
# manual replace for fused_add_rms_norm_static_fp8_quant
# this is the most effective optimization for llama
# failing to do this will result in many unnecessary copies
kwargs = node.kwargs
result = kwargs['result']
residual = kwargs['residual']
# Create a new call to
# torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.
default,
kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
if user.args[1] == 1:
replace_node = result
elif user.args[1] == 2:
replace_node = residual
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.rms_norm.default:
# manual replace for rms_norm
kwargs = node.kwargs
input = kwargs['input']
out = kwargs['out']
weight = kwargs['weight']
epsilon = kwargs['epsilon']
# Create a new call to torch.ops._C.rotary_embedding.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
replace_node = kwargs['result']
# Create a new call to torch.ops._C.rms_norm.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(torch.ops._C.rms_norm.default,
kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[
0] == torch.ops._C.rms_norm_static_fp8_quant.default: # noqa
# manual replace for rms_norm_static_fp8_quant
kwargs = node.kwargs
replace_node = kwargs['result']
# Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.rms_norm.default,
args=(out, input, weight, epsilon),
)
replace_node = out
torch.ops._C.rms_norm_static_fp8_quant.default,
kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
@ -136,7 +185,7 @@ def fix_functionalization(graph: fx.Graph):
input = kwargs['input']
out = kwargs['out']
# Create a new call to torch.ops._C.rotary_embedding.default
# Create a new call to torch.ops._C.silu_and_mul.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
@ -319,6 +368,13 @@ class VllmBackend:
The major work of this backend is to split the graph into
piecewise graphs, and pass them to the piecewise backend.
This backend also handles custom passes and adds them to Inductor config.
The order of the post-grad post-passes is:
1. post_grad_passes (constructor parameter)
2. config["post_grad_custom_post_pass"]
3. fix_functionalization
This way, all passes operate on a functionalized graph.
"""
compilation_configs: CompilationConfig
@ -330,8 +386,10 @@ class VllmBackend:
split_gm: fx.GraphModule
piecewise_graphs: List[SplitItem]
returned_callable: Callable
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable]
def __init__(self, ):
def __init__(self, post_grad_passes: Sequence[Callable] = ()):
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle()
@ -340,10 +398,30 @@ class VllmBackend:
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
self.post_grad_passes = post_grad_passes
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
def add_passes_to_config(self):
config = self.compilation_configs
passes = list(self.post_grad_passes)
passes = passes + [RedundantReshapesPass(config)]
if config.enable_fusion:
passes = passes + [FusionPass.instance(config)]
inductor_config = config.inductor_compile_config
if "post_grad_custom_post_pass" in inductor_config:
passes = passes + [inductor_config["post_grad_custom_post_pass"]]
# add the fix_functionalization pass last, so that all other
# passes operate on a functionalized graph
passes = passes + [fix_functionalization]
combined_pass = combine_fx_passes(passes)
inductor_config["post_grad_custom_post_pass"] = combined_pass
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
compilation_counter.num_graphs_seen += 1
@ -357,6 +435,7 @@ class VllmBackend:
# we get the sizes to capture for cudagraph
# from compilation context
self.compilation_configs = CompilationConfig.select_and_init_config()
self.add_passes_to_config()
self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.non_cudagraph_ops)

View File

@ -1,4 +1,5 @@
import copy
from pathlib import Path
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, PrivateAttr
@ -50,6 +51,12 @@ class CompilationConfig(BaseModel):
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- Custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
@ -72,6 +79,10 @@ class CompilationConfig(BaseModel):
cudagraph_num_of_warmups: int = 0
cudagraph_capture_sizes: Optional[List[int]] = None
dump_graph_stages: List[str] = Field(default_factory=list)
dump_graph_dir: Path = Field(default=Path("."))
enable_fusion: bool = True
# not configurable, computed after init
compile_sizes: List[int] = PrivateAttr
capture_sizes: List[int] = PrivateAttr
@ -81,7 +92,7 @@ class CompilationConfig(BaseModel):
if not isinstance(v, str):
assert callable(v), (
f"pass {k} should be a function or a qualified name")
self.inductor_passes[k] = v
self.inductor_compile_config[k] = v
continue
# resolve function from qualified name
@ -91,18 +102,6 @@ class CompilationConfig(BaseModel):
func = __import__(module).__dict__[func_name]
self.inductor_compile_config[k] = func
from vllm.compilation.backends import fix_functionalization
from vllm.utils import combine_fx_passes
if "post_grad_custom_post_pass" in self.inductor_compile_config:
self.inductor_compile_config[
"post_grad_custom_post_pass"] = combine_fx_passes(
fix_functionalization,
self.inductor_compile_config["post_grad_custom_post_pass"],
)
else:
self.inductor_compile_config[
"post_grad_custom_post_pass"] = fix_functionalization
def init_during_runtime(self):
"""To complete the initialization of config,
we need to know the compile context, which is only available

291
vllm/compilation/fusion.py Normal file
View File

@ -0,0 +1,291 @@
import operator
from typing import Iterable, List, Optional
import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)
from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger
logger = init_logger(__name__)
def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.rms_norm.default,
result=result_rms,
input=input,
weight=weight,
epsilon=1e-5)
at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at1[1],
scale=scale)
# result
return at2[1]
def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=1e-5)
# result
return at[1]
def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default,
input=input,
residual=residual,
weight=weight,
epsilon=1e-5)
at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at[1],
scale=scale)
# result, residual
return at1[1], at[2]
def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor, scale: torch.Tensor):
at = auto_functionalized(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=1e-5)
# result, residual
return at[1], at[2]
def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp8(*args, **kwargs):
fp8 = torch.float8_e4m3fn
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
# Utilities for post-processing multi-output matches
def is_func(node: torch.fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
op) -> Optional[torch.fx.Node]:
for node in nodes:
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
return node
return None
# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
node = find_auto_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: torch.fx.Node,
idx: int) -> Optional[torch.fx.Node]:
for user in node.users:
if is_func(user, operator.getitem) and user.args[1] == idx:
return user
return None
# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
class FusionPass(InductorPass):
"""
This pass fuses a pre-defined set of custom ops into fused ops.
It uses the torch pattern matcher to find the patterns and replace them.
It also manually processes multi-output matches, as those are broken in
the torch pattern matcher.
Because patterns can only be registered once, the pass is a singleton.
This will be addressed in a future version of PyTorch:
https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
"""
_instance: 'Optional[FusionPass]' = None
@classmethod
def instance(cls, config: CompilationConfig):
"""
Get the singleton instance of the FusionPass.
If the instance exists, the config is updated but
initialization is not repeated.
"""
if cls._instance is None:
cls._instance = FusionPass(config)
else:
cls._instance.config = config
return cls._instance
def __init__(self, config: CompilationConfig):
assert self.__class__._instance is None, \
"FusionPass singleton instance already exists"
super().__init__(config)
self.matches: List[Match] = []
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass")
# Fuse rms_norm + static_scaled_fp8_quant into
# rms_norm_static_fp8_quant
inputs = [
empty_fp8(5, 4),
empty_bf16(5, 4),
empty_bf16(5, 4),
empty_bf16(1, 5),
empty_fp32(1, 1)
]
register_replacement(rms_pattern_static, rms_replacement_static,
inputs, fwd_only, self.patterns)
# Fuse fused_add_rms_norm + static_scaled_fp8_quant into
# fused_add_rms_norm_static_fp8_quant
# Because pattern has 2 outputs, we need to manually process the match
# (see process_matches)
inputs = [
empty_fp8(5, 4),
empty_bf16(5, 4),
empty_bf16(5, 4),
empty_bf16(1, 5),
empty_fp32(1, 1)
]
register_replacement(rms_pattern_residual_static,
rms_replacement_residual_static,
inputs,
fwd_only,
self.patterns,
extra_check=lambda m: self.record_match(m))
def record_match(self, match: Match) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
self.matches.append(match)
# Return False to prevent automatic replacement.
return False
def process_matches(self, graph: torch.fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
"""
for match in self.matches:
# To avoid use-before-definition errors, insert replacement nodes
# after the last node in the match.
# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(graph.nodes):
if last_node_in_match in match.nodes:
break
else:
raise ValueError("No nodes in graph")
# Insert a new auto_functionalized node for the fused operation,
# as well as getitem nodes to extract the result and residual.
# The auto_functionalized node returns a tuple of
# (None, result, residual) - None is the function return value.
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with graph.inserting_after(last_node_in_match):
kwargs = match.kwargs
kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm
fused_node = graph.call_function(
auto_functionalized,
(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
),
kwargs=kwargs)
graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem,
(fused_node, 1))
residual_node_new = graph.call_function(
operator.getitem, (fused_node, 2))
# Last part of replacement is rebinding the users of nodes in the
# match to use the new nodes.
# Find the nodes in the match that we need to rebind
rms_node = find_auto_fn(match.nodes,
torch.ops._C.fused_add_rms_norm.default)
quant_node = find_auto_fn(
match.nodes, torch.ops._C.static_scaled_fp8_quant.default)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
# meta["val"] is used by de-functionalization and has to contain the
# value of the node (tuple of tensors) that would be returned by the
# functionalized node during tracing.
rms_tup = rms_node.meta["val"]
quant_tup = quant_node.meta["val"]
# The result of fused_node must be a tuple with the first element
# None (the function return value) and the remaining elements
# representing the mutated inputs.
fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2])
fused_node.meta["val"] = fused_tup
# Find the getitem nodes and replace their uses with the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)
# Finally, remove matched nodes
graph.eliminate_dead_code()
assert all(node not in graph.nodes for match in self.matches
for node in match.nodes)
def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, "before_fusion")
count = self.patterns.apply(graph)
logger.info("Replaced %s patterns", count)
self.dump_graph(graph, "after_pattern_match")
# Manually process multi-output matches (and run DCE)
self.process_matches(graph)
logger.info("Post-processed %s matches", len(self.matches))
self.dump_graph(graph, "after_fusion")
self.matches.clear()

View File

@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
import torch
from vllm.compilation.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger
logger = init_logger(__name__)
class InductorPass(ABC):
@abstractmethod
def __call__(self, graph: torch.fx.Graph):
raise NotImplementedError
def __init__(self, config: CompilationConfig):
self.config = config
def dump_graph(self, graph: torch.fx.Graph, stage: str):
if stage in self.config.dump_graph_stages:
# Make sure filename includes rank in the distributed setting
parallel = p_is_init() and get_tp_world_size() > 1
rank = f"-{get_tp_rank()}" if parallel else ""
filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"
logger.info("Printing graph to %s", filepath)
with open(filepath, "w") as f:
src = graph.python_code(root_module="self", verbose=True).src
# Add imports so it's not full of errors
print("import torch; from torch import device", file=f)
print(src, file=f)

View File

@ -0,0 +1,85 @@
from typing import Union
import torch.fx
from torch import SymInt
from vllm.compilation.fusion import is_func
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger
logger = init_logger(__name__)
class RedundantReshapesPass(InductorPass):
"""
This is an inductor pass that removes redundant reshape operations.
It is required for RMSNorm-quant fusion to work properly.
That's because apply_fp8_linear adds a reshape, which is redundant
in the 2D-case.
Example graph:
getitem_1: "f16[s0, 4096]" = ...
view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
Can be replaced with:
getitem_1: "f16[s0, 4096]" = ...
at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
out: "f8e4m3fn[s0, 4096]" = at[1]
"""
def __call__(self, graph: torch.fx.Graph):
self.dump_graph(graph, "before_reshapes")
count = 0
# Remove no-op reshapes/views:
for node in graph.nodes:
if is_func(node, torch.ops.aten.reshape.default):
input, shape = node.args[:2]
input_shape = input.meta["val"].shape
if len(shape) != len(input_shape):
# Reshape changing rank, skip
continue
if shape.count(-1) > 1:
# Invalid reshape args, skip
continue
if all(
self.dims_equivalent(s, i_s)
for s, i_s in zip(shape, input_shape)):
node.replace_all_uses_with(input)
graph.erase_node(node)
count += 1
logger.info("Removed %s no-op reshapes", count)
self.dump_graph(graph, "after_reshapes")
def dims_equivalent(self, dim: Union[int, torch.fx.Node],
i_dim: Union[int, SymInt]) -> bool:
"""
This function checks if two dimensions are equivalent.
:param dim: The dimension arg to reshape
:param i_dim: The corresponding dimension in the input tensor
:return: Are the dimensions equivalent?
There are three cases in which the dimensions are equivalent:
1. The dimensions are equal (both integers)
2. The reshape dimension is -1 (i.e. inferred)
3. The dimensions both correspond to the same SymInt
While case 2 does not guarantee the dimensions are equal,
they are equal if all other dimensions are equal.
In case 3, the reshape dimension is a torch.fx.Node,
and its value is a SymInt. That value is equal to the
input dimension.
"""
# Case 1 and 2
if dim == i_dim or dim == -1:
return True
# Case 3
return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim

View File

@ -68,6 +68,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
VLLM_CUSTOM_OPS: List[str] = []
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
@ -226,6 +227,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS":
lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":