[torch.compile] Fuse RMSNorm with quant (#9138)

Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@126.com>

parent e1b5a82179
commit 4f93dfe952
@ -191,6 +191,7 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/layernorm_quant_kernels.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
@ -1,21 +1,13 @@
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "type_convert.cuh"
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {

@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
}
}

/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.

Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;

__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;

__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];

__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}

__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
csrc/layernorm_quant_kernels.cu (new file, 234 lines)
@ -0,0 +1,234 @@
/*
 * This file contains the CUDA kernels for the fused quantized layernorm.
 * The kernels correspond to the kernels in layernorm_kernels.cu, except they
 * also produce quantized output directly.
 * Currently, only static fp8 quantization is supported.
 */

#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out,          // [..., hidden_size]
const scalar_t* __restrict__ input,  // [..., hidden_size]
const scalar_t* __restrict__ weight,  // [hidden_size]
const float* __restrict__ scale,      // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

// invert scale to avoid division
float const scale_inv = 1.0f / *scale;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)input[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
}
}

/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out,      // [..., hidden_size]
scalar_t* __restrict__ input,    // [..., hidden_size]
scalar_t* __restrict__ residual,  // [..., hidden_size]
const scalar_t* __restrict__ weight,  // [hidden_size]
const float* __restrict__ scale,      // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

const int vec_hidden_size = hidden_size / width;
__shared__ float s_variance;
float variance = 0.0f;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto* __restrict__ input_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
auto* __restrict__ residual_v =
reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
auto* __restrict__ weight_v =
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = input_v[id];
temp += residual_v[id];
variance += temp.sum_squares();
residual_v[id] = temp;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

// invert scale to avoid division
float const scale_inv = 1.0f / *scale;

for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
int id = blockIdx.x * vec_hidden_size + idx;
_f16Vec<scalar_t, width> temp = residual_v[id];
temp *= s_variance;
temp *= weight_v[idx];
#pragma unroll
for (int i = 0; i < width; ++i) {
out[id * width + i] =
scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
}
}
}

/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
FP8_TYPE* __restrict__ out,      // [..., hidden_size]
scalar_t* __restrict__ input,    // [..., hidden_size]
scalar_t* __restrict__ residual,  // [..., hidden_size]
const scalar_t* __restrict__ weight,  // [hidden_size]
const float* __restrict__ scale,      // [1]
const float epsilon, const int num_tokens, const int hidden_size) {
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
scalar_t z = input[blockIdx.x * hidden_size + idx];
z += residual[blockIdx.x * hidden_size + idx];
float x = (float)z;
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}

using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

// invert scale to avoid division
float const scale_inv = 1.0f / *scale;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float)residual[blockIdx.x * hidden_size + idx];
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
out[blockIdx.x * hidden_size + idx] =
scaled_fp8_conversion<true>(out_norm, scale_inv);
}
}

}  // namespace vllm

void rms_norm_static_fp8_quant(torch::Tensor& out,    // [..., hidden_size]
torch::Tensor& input,  // [..., hidden_size]
torch::Tensor& weight,  // [hidden_size]
torch::Tensor& scale,   // [1]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
num_tokens, hidden_size);
});
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
});

void fused_add_rms_norm_static_fp8_quant(
torch::Tensor& out,       // [..., hidden_size],
torch::Tensor& input,     // [..., hidden_size]
torch::Tensor& residual,  // [..., hidden_size]
torch::Tensor& weight,    // [hidden_size]
torch::Tensor& scale,     // [1]
double epsilon) {
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;

dim3 grid(num_tokens);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const int max_block_size = (num_tokens < 256) ? 1024 : 256;
dim3 block(std::min(hidden_size, max_block_size));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
bool ptrs_are_aligned =
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
if (ptrs_are_aligned && hidden_size % 8 == 0) {
LAUNCH_FUSED_ADD_RMS_NORM(8);
} else {
LAUNCH_FUSED_ADD_RMS_NORM(0);
}
}
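For readers coming from the Python side, the following is a hedged, unfused PyTorch sketch of what the new rms_norm_static_fp8_quant kernel computes. Names are illustrative; the CUDA kernel above is authoritative (note that it casts the normalized value back to the input dtype before applying the weight, and clamps to the FP8 maximum during conversion):

```python
import torch

def rms_norm_fp8_quant_ref(x: torch.Tensor, weight: torch.Tensor,
                           scale: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: normalize by rsqrt of the mean of squares along the hidden dim.
    orig_dtype = x.dtype
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_norm = (x.float() * torch.rsqrt(variance + eps)).to(orig_dtype) * weight
    # Static quantization: divide by the per-tensor scale, clamp to the FP8 range,
    # then convert to float8_e4m3fn (the kernel multiplies by 1/scale instead).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    q = (x_norm.float() / scale).clamp(-fp8_max, fp8_max)
    return q.to(torch.float8_e4m3fn)
```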
csrc/ops.h (10 lines changed)
@ -56,6 +56,16 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch::Tensor& weight, torch::Tensor& scale,
double epsilon);

void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& residual,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
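These declarations are exposed to Python through torch.ops._C (see the torch_bindings changes further down). A minimal usage sketch, mirroring the fused path exercised in tests/kernels/test_layernorm.py and assuming a CUDA build of vLLM's _C extension:

```python
import torch

hidden_size = 1024
x = torch.randn(16, hidden_size, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=torch.float16, device="cuda")
scale = torch.tensor(0.02, dtype=torch.float32, device="cuda")
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)

# RMSNorm + static FP8 quant in one kernel; writes `out` and updates `residual` in place.
torch.ops._C.fused_add_rms_norm_static_fp8_quant(out, x, residual, weight,
                                                 scale, 1e-6)
```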
@ -1,185 +1,16 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "common.cuh"
#include "dispatch_utils.h"

#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));

return old;
}

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;

__syncthreads();

// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
}
}

template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};

typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);

int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}

return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input,
csrc/quantization/fp8/common.cuh (new file, 172 lines)
@ -0,0 +1,172 @@
#pragma once

#include <cmath>

#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include <c10/util/Float8_e4m3fnuz.h>
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));

return old;
}

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;

__syncthreads();

// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
}
}

template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};

typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);

int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}

return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}

}  // namespace vllm
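For orientation, a hedged PyTorch sketch of what segmented_max_reduction produces when a dynamic scale is needed: the per-tensor scale is abs-max(input) / FP8_E4M3_MAX, accumulated atomically across thread blocks on the GPU. Illustrative only; the kernel above is the source of truth.

```python
import torch

def dynamic_fp8_scale_ref(x: torch.Tensor) -> torch.Tensor:
    # 448.0 for float8_e4m3fn on CUDA; the ROCm path above uses 224.0 instead.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    return (x.abs().amax().float() / fp8_max).reshape(1)
```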
@ -101,7 +101,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> "
"()");
ops.impl("rms_norm", torch::kCUDA, &rms_norm);

@ -111,6 +111,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"float epsilon) -> ()");
ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops.def(
"rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
"Tensor scale, float epsilon) -> "
"()");
ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
&rms_norm_static_fp8_quant);

// In-place fused Add and RMS Normalization.
ops.def(
"fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
"Tensor! residual, Tensor weight, "
"Tensor scale, float epsilon) -> ()");
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant);

// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(

@ -322,18 +339,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Compute FP8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
"static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
"()");
ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

// Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
"dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
"-> "
"()");
ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

// Compute dynamic-per-token FP8 quantized tensor and scaling factor.
ops.def(
"dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
"dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
"Tensor! scale, Tensor? scale_ub) -> "
"()");
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,

@ -341,13 +360,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// Compute int8 quantized tensor for given scaling factor.
ops.def(
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
"Tensor? azp) -> ()");
ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

// Compute int8 quantized tensor and scaling factor
ops.def(
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
"dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
"Tensor!? azp) -> ()");
ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
&dynamic_scaled_int8_quant);
csrc/type_convert.cuh (new file, 165 lines)
@ -0,0 +1,165 @@
#pragma once

#include <torch/all.h>

#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
operators/constructors are not consistently implemented by HIP/CUDA, so
a generic conversion via type casts cannot be implemented.

Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template <typename torch_type>
struct _typeConvert {
static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
static constexpr bool exists = true;
using hip_type = __half;
using packed_hip_type = __half2;

__device__ static inline float convert(hip_type x) { return __half2float(x); }
__device__ static inline float2 convert(packed_hip_type x) {
return __half22float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2half_rn(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22half2_rn(x);
}
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
static constexpr bool exists = true;
using hip_type = __nv_bfloat16;
using packed_hip_type = __nv_bfloat162;

__device__ static inline float convert(hip_type x) {
return __bfloat162float(x);
}
__device__ static inline float2 convert(packed_hip_type x) {
return __bfloat1622float2(x);
}
__device__ static inline hip_type convert(float x) {
return __float2bfloat16(x);
}
__device__ static inline packed_hip_type convert(float2 x) {
return __float22bfloat162_rn(x);
}
};
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert(width > 0 && (width & (width - 1)) == 0,
"Width is not a positive power of 2!");
using Converter = _typeConvert<scalar_t>;
using T1 = typename Converter::hip_type;
using T2 = typename Converter::packed_hip_type;
T1 data[width];

__device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp += T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] += other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
T2 temp{data[i], data[i + 1]};
temp *= T2{other.data[i], other.data[i + 1]};
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) data[i] *= other.data[i];
}
return *this;
}

__device__ _f16Vec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
temp_f.x *= scale;
temp_f.y *= scale;
T2 temp = Converter::convert(temp_f);
data[i] = temp.x;
data[i + 1] = temp.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float temp = Converter::convert(data[i]) * scale;
data[i] = Converter::convert(temp);
}
}
return *this;
}

__device__ float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = Converter::convert(T2{data[i], data[i + 1]});
result += z.x * z.x + z.y * z.y;
}
} else {
#pragma unroll
for (int i = 0; i < width; ++i) {
float x = Converter::convert(data[i]);
result += x * x;
}
}
return result;
}
};
}  // namespace vllm
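The 16-byte alignment requirement mentioned above is what the host-side launcher checks before picking the width-8 vectorized kernel. An equivalent check expressed from Python, purely as an illustration (the real check happens in C++ inside fused_add_rms_norm_static_fp8_quant):

```python
import torch

def can_use_vectorized(x: torch.Tensor, residual: torch.Tensor,
                       weight: torch.Tensor) -> bool:
    # All three tensors must start on a 16-byte boundary, and the hidden
    # dimension must be divisible by the vector width of 8 half-precision values.
    ptrs_aligned = all(t.data_ptr() % 16 == 0 for t in (x, residual, weight))
    return ptrs_aligned and x.shape[-1] % 8 == 0
```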
tests/compile/backend.py (new file, 33 lines)
@ -0,0 +1,33 @@
from copy import deepcopy
from typing import Callable

import torch


class TestBackend:
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.
    """

    def __init__(self, *args: Callable[[torch.fx.Graph], None]):
        self.custom_passes = args
        from torch._inductor import config
        self.current_config = config.shallow_copy_dict()
        self.current_config['post_grad_custom_post_pass'] = self.post_pass

    def __call__(self, graph: torch.fx.GraphModule, example_inputs):
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
                          config_patches=self.current_config)

    def post_pass(self, graph: torch.fx.Graph):
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph
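A hedged usage sketch of TestBackend (assuming a CUDA device and the vLLM test package on the path): pass an instance as the backend to torch.compile, then inspect the graphs it captured around the custom passes.

```python
import torch
from tests.compile.backend import TestBackend

def noop_pass(graph: torch.fx.Graph) -> None:
    pass  # stand-in custom pass, for illustration only

backend = TestBackend(noop_pass)
model = torch.nn.Linear(8, 8).cuda().half()
compiled = torch.compile(model, backend=backend)
compiled(torch.randn(4, 8, device="cuda", dtype=torch.float16))

# graph_pre_pass / graph_post_pass hold copies of the post-grad FX graph
# before and after the custom passes ran.
print(len(list(backend.graph_post_pass.nodes)))
```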
tests/compile/test_fusion.py (new file, 92 lines)
@ -0,0 +1,92 @@
import pytest
import torch
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                     find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear)

from .backend import TestBackend


class TestModel(torch.nn.Module):

    def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

    def forward(self, x):
        resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3


# Init does pattern registration, which can only happen once
config = CompilationConfig(enable_fusion=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    if eps != 1e-5:
        pytest.skip("Only test eps=1e-5 for now")

    # Reshape pass is needed for the fusion pass to work
    backend = TestBackend(reshape_pass, fusion_pass)
    model = TestModel(hidden_size, eps)

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    result = model(x)

    model2 = torch.compile(model, backend=backend)
    result2 = model2(x)

    # Check that it gives the same answer
    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)

    # Check substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes

    rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
    add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not
    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
    find_auto_fn(pre_nodes, fp8_quant)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    find_auto_fn(post_nodes, rms_quant)
    find_auto_fn(post_nodes, add_rms_quant)
    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
@ -1,13 +1,14 @@
import pytest
import torch

from tests.kernels.quant_utils import FP8_DTYPE
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform

DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
                8199]  # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]

@ -59,3 +60,75 @@ def test_rms_norm(
    else:
        opcheck(torch.ops._C.rms_norm,
                (out, x, layer.weight.data, layer.variance_epsilon))


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_rms_norm_quant(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    dtype: torch.dtype,
    quant_scale: float,
    seed: int,
    device: str,
) -> None:
    current_platform.seed_everything(seed)
    torch.set_default_device(device)

    weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
    scale = 1 / (2 * hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
    x *= scale
    if add_residual:
        residual = torch.randn_like(x) * scale
        residual_fused = residual.clone()
    else:
        residual = residual_fused = None

    out_norm = torch.empty_like(x)
    out_quant = torch.empty_like(x, dtype=FP8_DTYPE)
    out_quant_fused = torch.empty_like(out_quant)

    quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32)

    if add_residual:
        torch.ops._C.fused_add_rms_norm_static_fp8_quant(
            out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)

        # Unfused kernel is in-place so it goes second
        # Also use a separate clone of x to avoid modifying the input
        x_unfused = x.clone()
        torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
        torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
                                             quant_scale_t)

        torch.cuda.synchronize()
        torch.testing.assert_close(residual_fused,
                                   residual,
                                   atol=1e-2,
                                   rtol=1e-2)

        opcheck(
            torch.ops._C.fused_add_rms_norm_static_fp8_quant,
            (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
    else:
        torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight,
                                               quant_scale_t, 1e-6)

        torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
        torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm,
                                             quant_scale_t)

        opcheck(torch.ops._C.rms_norm_static_fp8_quant,
                (out_quant_fused, x, weight, quant_scale_t, 1e-6))

    torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
                               out_quant.to(dtype=torch.float32),
                               atol=1e-3,
                               rtol=1e-3)
@ -2,7 +2,8 @@ import copy
import dataclasses
import operator
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
Union)
from unittest.mock import patch

import torch

@ -10,11 +11,13 @@ import torch.fx as fx

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
from vllm.utils import combine_fx_passes, weak_ref_tensors

from .config import CompilationConfig
from .counter import compilation_counter
from .fusion import FusionPass
from .levels import CompilationLevel
from .reshapes import RedundantReshapesPass

logger = init_logger(__name__)

@ -99,28 +102,74 @@ def fix_functionalization(graph: fx.Graph):
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif (node.args[0] ==
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default):
# manual replace for fused_add_rms_norm_static_fp8_quant
# this is the most effective optimization for llama
# failing to do this will result in many unnecessary copies

kwargs = node.kwargs

result = kwargs['result']
residual = kwargs['residual']

# Create a new call to
# torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.
default,
kwargs=kwargs)

for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem:  # noqa
# Remove the getitem node
if user.args[1] == 1:
replace_node = result
elif user.args[1] == 2:
replace_node = residual
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)

elif node.args[0] == torch.ops._C.rms_norm.default:
# manual replace for rms_norm

kwargs = node.kwargs

input = kwargs['input']
out = kwargs['out']
weight = kwargs['weight']
epsilon = kwargs['epsilon']
# Create a new call to torch.ops._C.rotary_embedding.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351  # noqa
replace_node = kwargs['result']
# Create a new call to torch.ops._C.rms_norm.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(torch.ops._C.rms_norm.default,
kwargs=kwargs)

for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem:  # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)

elif node.args[
0] == torch.ops._C.rms_norm_static_fp8_quant.default:  # noqa
# manual replace for rms_norm_static_fp8_quant

kwargs = node.kwargs

replace_node = kwargs['result']
# Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default  # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.rms_norm.default,
args=(out, input, weight, epsilon),
)

replace_node = out
torch.ops._C.rms_norm_static_fp8_quant.default,
kwargs=kwargs)

for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem:  # noqa

@ -136,7 +185,7 @@ def fix_functionalization(graph: fx.Graph):
input = kwargs['input']
out = kwargs['out']

# Create a new call to torch.ops._C.rotary_embedding.default
# Create a new call to torch.ops._C.silu_and_mul.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351  # noqa
with graph.inserting_before(node):
# just insert the call to the custom op

@ -319,6 +368,13 @@ class VllmBackend:

The major work of this backend is to split the graph into
piecewise graphs, and pass them to the piecewise backend.

This backend also handles custom passes and adds them to Inductor config.
The order of the post-grad post-passes is:
1. post_grad_passes (constructor parameter)
2. config["post_grad_custom_post_pass"]
3. fix_functionalization
This way, all passes operate on a functionalized graph.
"""

compilation_configs: CompilationConfig

@ -330,8 +386,10 @@ class VllmBackend:
split_gm: fx.GraphModule
piecewise_graphs: List[SplitItem]
returned_callable: Callable
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable]

def __init__(self, ):
def __init__(self, post_grad_passes: Sequence[Callable] = ()):
global global_graph_pool
if global_graph_pool is None:
global_graph_pool = torch.cuda.graph_pool_handle()

@ -340,10 +398,30 @@ class VllmBackend:
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
self.post_grad_passes = post_grad_passes

# `torch.compile` is JIT compiled, so we don't need to
# do anything here

def add_passes_to_config(self):
config = self.compilation_configs
passes = list(self.post_grad_passes)

passes = passes + [RedundantReshapesPass(config)]

if config.enable_fusion:
passes = passes + [FusionPass.instance(config)]

inductor_config = config.inductor_compile_config
if "post_grad_custom_post_pass" in inductor_config:
passes = passes + [inductor_config["post_grad_custom_post_pass"]]

# add the fix_functionalization pass last, so that all other
# passes operate on a functionalized graph
passes = passes + [fix_functionalization]
combined_pass = combine_fx_passes(passes)
inductor_config["post_grad_custom_post_pass"] = combined_pass

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

compilation_counter.num_graphs_seen += 1

@ -357,6 +435,7 @@ class VllmBackend:
# we get the sizes to capture for cudagraph
# from compilation context
self.compilation_configs = CompilationConfig.select_and_init_config()
self.add_passes_to_config()

self.split_gm, self.piecewise_graphs = split_graph(
graph, self.compilation_configs.non_cudagraph_ops)
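The pass ordering documented in the VllmBackend docstring (user passes, then any pre-existing post_grad_custom_post_pass, then fix_functionalization) is enforced by composing everything into a single callable. A minimal sketch of what such a combinator could look like; the real helper lives in vllm.utils as combine_fx_passes, and this stand-in is illustrative, not its actual code:

```python
from typing import Callable, List

import torch.fx as fx

def combine_passes(
        passes: List[Callable[[fx.Graph], None]]) -> Callable[[fx.Graph], None]:
    def combined(graph: fx.Graph) -> None:
        # Run each pass in list order, mutating the graph in place.
        for p in passes:
            p(graph)
    return combined
```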
vllm/compilation/config.py
@@ -1,4 +1,5 @@
import copy
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, PrivateAttr

@@ -50,6 +51,12 @@ class CompilationConfig(BaseModel):
        name because the config uses json format. If we pass the config
        from Python, functions can also be passed directly via Python object
        constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
    - Custom inductor passes:
        - dump_graph_stages: list of stages for which we want to dump the graph.
          Each pass defines its own stages (before, after, maybe in-between).
        - dump_graph_dir: directory to dump the graph. Default is .
        - enable_fusion: whether to enable the custom fusion pass.
          TODO better pass enabling system.

    Why we have different sizes for cudagraph and inductor:
    - cudagraph: a cudagraph captured for a specific size can only be used

@@ -72,6 +79,10 @@ class CompilationConfig(BaseModel):
    cudagraph_num_of_warmups: int = 0
    cudagraph_capture_sizes: Optional[List[int]] = None

    dump_graph_stages: List[str] = Field(default_factory=list)
    dump_graph_dir: Path = Field(default=Path("."))
    enable_fusion: bool = True

    # not configurable, computed after init
    compile_sizes: List[int] = PrivateAttr
    capture_sizes: List[int] = PrivateAttr
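As a usage illustration, the new knobs could be set like this from Python (a hypothetical snippet; the stage names match the ones FusionPass uses further below):

from pathlib import Path

from vllm.compilation.config import CompilationConfig

config = CompilationConfig(
    enable_fusion=True,
    # dump the FX graph right before and after the fusion pass runs
    dump_graph_stages=["before_fusion", "after_fusion"],
    dump_graph_dir=Path("graph_dumps"),
)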
@@ -81,7 +92,7 @@ class CompilationConfig(BaseModel):
            if not isinstance(v, str):
                assert callable(v), (
                    f"pass {k} should be a function or a qualified name")
                self.inductor_passes[k] = v
                self.inductor_compile_config[k] = v
                continue

            # resolve function from qualified name
@@ -91,18 +102,6 @@ class CompilationConfig(BaseModel):
            func = __import__(module).__dict__[func_name]
            self.inductor_compile_config[k] = func

        from vllm.compilation.backends import fix_functionalization
        from vllm.utils import combine_fx_passes
        if "post_grad_custom_post_pass" in self.inductor_compile_config:
            self.inductor_compile_config[
                "post_grad_custom_post_pass"] = combine_fx_passes(
                    fix_functionalization,
                    self.inductor_compile_config["post_grad_custom_post_pass"],
                )
        else:
            self.inductor_compile_config[
                "post_grad_custom_post_pass"] = fix_functionalization

    def init_during_runtime(self):
        """To complete the initialization of config,
        we need to know the compile context, which is only available
vllm/compilation/fusion.py (new file, 291 lines)
@@ -0,0 +1,291 @@
import operator
from typing import Iterable, List, Optional

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                             fwd_only, register_replacement)

from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger

logger = init_logger(__name__)
def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
                       input: torch.Tensor, weight: torch.Tensor,
                       scale: torch.Tensor):
    at1 = auto_functionalized(torch.ops._C.rms_norm.default,
                              result=result_rms,
                              input=input,
                              weight=weight,
                              epsilon=1e-5)
    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
                              result=result,
                              input=at1[1],
                              scale=scale)

    # result
    return at2[1]


def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor,
                           input: torch.Tensor, weight: torch.Tensor,
                           scale: torch.Tensor):
    at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default,
                             result=result,
                             input=input,
                             weight=weight,
                             scale=scale,
                             epsilon=1e-5)

    # result
    return at[1]


def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor,
                                residual: torch.Tensor, weight: torch.Tensor,
                                scale: torch.Tensor):
    at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default,
                             input=input,
                             residual=residual,
                             weight=weight,
                             epsilon=1e-5)
    at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
                              result=result,
                              input=at[1],
                              scale=scale)

    # result, residual
    return at1[1], at[2]


def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor,
                                    residual: torch.Tensor,
                                    weight: torch.Tensor, scale: torch.Tensor):
    at = auto_functionalized(
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
        result=result,
        input=input,
        residual=residual,
        weight=weight,
        scale=scale,
        epsilon=1e-5)
    # result, residual
    return at[1], at[2]
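To make the patterns above concrete, here is a rough eager-mode reference of what the unfused pair computes (an illustrative sketch under the assumption that it mirrors the custom ops; the real computation is done by the CUDA kernels registered under torch.ops._C):

import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                 eps: float = 1e-5) -> torch.Tensor:
    # RMS-normalize over the hidden dimension, then scale by the weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def static_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Static (per-tensor) scaled quantization to float8_e4m3fn.
    finfo = torch.finfo(torch.float8_e4m3fn)
    return (x.float() / scale).clamp(finfo.min, finfo.max).to(
        torch.float8_e4m3fn)


# Unfused: two kernels and an intermediate bf16 tensor.
# Fused (rms_norm_static_fp8_quant): one kernel, no intermediate.
x = torch.randn(5, 4, dtype=torch.bfloat16)
w = torch.ones(4, dtype=torch.bfloat16)
scale = torch.tensor(1.0)
out = static_fp8_quant_ref(rms_norm_ref(x, w), scale)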
def empty_bf16(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


def empty_fp8(*args, **kwargs):
    fp8 = torch.float8_e4m3fn
    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")


def empty_fp32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


# Utilities for post-processing multi-output matches
def is_func(node: torch.fx.Node, target) -> bool:
    return node.op == "call_function" and node.target == target


# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
                       op) -> Optional[torch.fx.Node]:
    for node in nodes:
        if is_func(node, auto_functionalized) and node.args[0] == op:  # noqa
            return node
    return None


# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
    node = find_auto_fn_maybe(nodes, op)
    assert node is not None, f"Could not find {op} in nodes {nodes}"
    return node


# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: torch.fx.Node,
                       idx: int) -> Optional[torch.fx.Node]:
    for user in node.users:
        if is_func(user, operator.getitem) and user.args[1] == idx:
            return user
    return None


# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
    ret = find_getitem_maybe(node, idx)
    assert ret is not None, f"Could not find getitem {idx} in node {node}"
    return ret
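A small standalone illustration of the getitem helpers (using torch.topk as a stand-in for a tuple-returning op, since auto_functionalized nodes only exist in post-grad graphs):

import operator

import torch
import torch.fx as fx

from vllm.compilation.fusion import find_getitem, is_func


def f(x):
    out = torch.topk(x, k=2)
    return out[0]


gm = fx.symbolic_trace(f)
topk_node = next(n for n in gm.graph.nodes if n.target is torch.topk)
values_node = find_getitem(topk_node, 0)  # the operator.getitem node for out[0]
assert is_func(values_node, operator.getitem)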
class FusionPass(InductorPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.
    It also manually processes multi-output matches, as those are broken in
    the torch pattern matcher.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    _instance: 'Optional[FusionPass]' = None

    @classmethod
    def instance(cls, config: CompilationConfig):
        """
        Get the singleton instance of the FusionPass.
        If the instance exists, the config is updated but
        initialization is not repeated.
        """
        if cls._instance is None:
            cls._instance = FusionPass(config)
        else:
            cls._instance.config = config
        return cls._instance

    def __init__(self, config: CompilationConfig):
        assert self.__class__._instance is None, \
            "FusionPass singleton instance already exists"
        super().__init__(config)

        self.matches: List[Match] = []
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="fusion_pass")

        # Fuse rms_norm + static_scaled_fp8_quant into
        # rms_norm_static_fp8_quant
        inputs = [
            empty_fp8(5, 4),
            empty_bf16(5, 4),
            empty_bf16(5, 4),
            empty_bf16(1, 5),
            empty_fp32(1, 1)
        ]
        register_replacement(rms_pattern_static, rms_replacement_static,
                             inputs, fwd_only, self.patterns)

        # Fuse fused_add_rms_norm + static_scaled_fp8_quant into
        # fused_add_rms_norm_static_fp8_quant
        # Because pattern has 2 outputs, we need to manually process the match
        # (see process_matches)
        inputs = [
            empty_fp8(5, 4),
            empty_bf16(5, 4),
            empty_bf16(5, 4),
            empty_bf16(1, 5),
            empty_fp32(1, 1)
        ]
        register_replacement(rms_pattern_residual_static,
                             rms_replacement_residual_static,
                             inputs,
                             fwd_only,
                             self.patterns,
                             extra_check=lambda m: self.record_match(m))
    def record_match(self, match: Match) -> bool:
        # Hijack the extra_check to record the match and
        # save it for post-processing.
        self.matches.append(match)

        # Return False to prevent automatic replacement.
        return False

    def process_matches(self, graph: torch.fx.Graph):
        """
        Manually process multi-output matches and replace them with fused nodes.
        This is necessary because the automatic replacement for multi-output
        matches is broken: https://github.com/pytorch/pytorch/issues/137280
        """
        for match in self.matches:
            # To avoid use-before-definition errors, insert replacement nodes
            # after the last node in the match.
            # match.nodes is not guaranteed to be sorted.
            # Find the last node in the match.
            for last_node_in_match in reversed(graph.nodes):
                if last_node_in_match in match.nodes:
                    break
            else:
                raise ValueError("No nodes in graph")

            # Insert a new auto_functionalized node for the fused operation,
            # as well as getitem nodes to extract the result and residual.
            # The auto_functionalized node returns a tuple of
            # (None, result, residual) - None is the function return value.
            # The resulting graph looks like this:
            # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...)  # noqa
            # result_node_new = at[1]
            # residual_node_new = at[2]
            with graph.inserting_after(last_node_in_match):
                kwargs = match.kwargs
                kwargs["epsilon"] = 1e-5  # Currently hard-coded in RMSNorm

                fused_node = graph.call_function(
                    auto_functionalized,
                    (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
                     ),
                    kwargs=kwargs)

                graph.inserting_after(fused_node)
                result_node_new = graph.call_function(operator.getitem,
                                                      (fused_node, 1))
                residual_node_new = graph.call_function(
                    operator.getitem, (fused_node, 2))

            # Last part of replacement is rebinding the users of nodes in the
            # match to use the new nodes.

            # Find the nodes in the match that we need to rebind
            rms_node = find_auto_fn(match.nodes,
                                    torch.ops._C.fused_add_rms_norm.default)
            quant_node = find_auto_fn(
                match.nodes, torch.ops._C.static_scaled_fp8_quant.default)

            assert len(rms_node.users) == 2
            assert len(quant_node.users) == 1

            # meta["val"] is used by de-functionalization and has to contain the
            # value of the node (tuple of tensors) that would be returned by the
            # functionalized node during tracing.

            rms_tup = rms_node.meta["val"]
            quant_tup = quant_node.meta["val"]

            # The result of fused_node must be a tuple with the first element
            # None (the function return value) and the remaining elements
            # representing the mutated inputs.
            fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2])
            fused_node.meta["val"] = fused_tup

            # Find the getitem nodes and replace their uses with the new nodes.
            # The old nodes will be removed by DCE at the end of the pass.
            find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
            find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)

        # Finally, remove matched nodes
        graph.eliminate_dead_code()
        assert all(node not in graph.nodes for match in self.matches
                   for node in match.nodes)
    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_fusion")

        count = self.patterns.apply(graph)
        logger.info("Replaced %s patterns", count)
        self.dump_graph(graph, "after_pattern_match")

        # Manually process multi-output matches (and run DCE)
        self.process_matches(graph)
        logger.info("Post-processed %s matches", len(self.matches))
        self.dump_graph(graph, "after_fusion")
        self.matches.clear()
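After the pass runs, one simple sanity check is to look for the fused ops in the graph. This is a hedged sketch built on the helpers defined in this file (it assumes a vLLM build with the custom _C ops loaded, and is not part of the pass itself):

import torch
import torch.fx

from vllm.compilation.fusion import find_auto_fn_maybe


def fusion_applied(graph: torch.fx.Graph) -> bool:
    # True if either fused op is present after running FusionPass.
    fused_ops = [
        torch.ops._C.rms_norm_static_fp8_quant.default,
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
    ]
    return any(
        find_auto_fn_maybe(graph.nodes, op) is not None for op in fused_ops)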
vllm/compilation/inductor_pass.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod

import torch

from vllm.compilation.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
    get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger

logger = init_logger(__name__)


class InductorPass(ABC):

    @abstractmethod
    def __call__(self, graph: torch.fx.Graph):
        raise NotImplementedError

    def __init__(self, config: CompilationConfig):
        self.config = config

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        if stage in self.config.dump_graph_stages:
            # Make sure filename includes rank in the distributed setting
            parallel = p_is_init() and get_tp_world_size() > 1
            rank = f"-{get_tp_rank()}" if parallel else ""
            filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"

            logger.info("Printing graph to %s", filepath)
            with open(filepath, "w") as f:
                src = graph.python_code(root_module="self", verbose=True).src
                # Add imports so it's not full of errors
                print("import torch; from torch import device", file=f)
                print(src, file=f)
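As a usage note, each concrete pass picks its own stage names and calls dump_graph around its work; a hypothetical subclass (not part of this change) might look like:

import torch.fx

from vllm.compilation.inductor_pass import InductorPass


class NoopPass(InductorPass):  # hypothetical example for illustration
    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_noop")
        # a real pass would transform the graph here
        self.dump_graph(graph, "after_noop")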
vllm/compilation/reshapes.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from typing import Union

import torch.fx
from torch import SymInt

from vllm.compilation.fusion import is_func
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger

logger = init_logger(__name__)


class RedundantReshapesPass(InductorPass):
    """
    This is an inductor pass that removes redundant reshape operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case.

    Example graph:

    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]
    """

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_reshapes")
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if all(
                        self.dims_equivalent(s, i_s)
                        for s, i_s in zip(shape, input_shape)):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

        logger.info("Removed %s no-op reshapes", count)

        self.dump_graph(graph, "after_reshapes")

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.
        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
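To see the pass in action outside of torch.compile, here is a hand-built sketch (assumptions: meta["val"] entries are filled in manually the way post-grad graphs provide them, and CompilationConfig is constructible with defaults):

import torch
import torch.fx as fx

from vllm.compilation.config import CompilationConfig
from vllm.compilation.reshapes import RedundantReshapesPass

# Build a tiny graph by hand with a rank- and shape-preserving reshape.
graph = fx.Graph()
x = graph.placeholder("x")
x.meta["val"] = torch.empty(8, 4096, dtype=torch.float16, device="meta")
view = graph.call_function(torch.ops.aten.reshape.default, (x, [-1, 4096]))
out = graph.call_function(torch.ops.aten.relu.default, (view, ))
graph.output(out)

# [-1, 4096] matches the input's [8, 4096], so the reshape is removed and
# its users are rebound to `x`.
RedundantReshapesPass(CompilationConfig())(graph)
assert not any(n.target == torch.ops.aten.reshape.default
               for n in graph.nodes)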
vllm/envs.py
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
    VLLM_SKIP_P2P_CHECK: bool = False
    VLLM_TORCH_COMPILE_LEVEL: int = 0
    VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
    VLLM_CUSTOM_OPS: List[str] = []
    VLLM_DISABLED_KERNELS: List[str] = []
    VLLM_USE_V1: bool = False

@@ -226,6 +227,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    # and disabled when running with Inductor (compile_level >= Inductor).
    "VLLM_CUSTOM_OPS":
    lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),

    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
    "LOCAL_RANK":