[torch.compile] Fuse RMSNorm with quant (#9138)
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: youkaichao <youkaichao@126.com>
Parent: e1b5a82179
Commit: 4f93dfe952
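At a high level, this change adds fused CUDA kernels (`rms_norm_static_fp8_quant`, `fused_add_rms_norm_static_fp8_quant`) plus a torch.compile fusion pass (`FusionPass`, imported in the files below) that rewrites an RMSNorm custom op followed by a static FP8 quantization op into a single fused op. A minimal Python-level sketch of the rewrite, using the op names and signatures registered in this diff; the tensors, shapes, and epsilon value are illustrative assumptions, and `torch.ops._C` must already be loaded by vLLM's extension:

import torch

# Illustrative inputs (shapes/dtypes are assumptions, not taken from the diff).
x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
weight = torch.ones(4096, dtype=torch.float16, device="cuda")
scale = torch.tensor(0.5, dtype=torch.float32, device="cuda")  # static quant scale
eps = 1e-6

# Unfused: RMSNorm kernel, then a separate static FP8 quant kernel.
y = torch.empty_like(x)
torch.ops._C.rms_norm(y, x, weight, eps)
q_unfused = torch.empty_like(x, dtype=torch.float8_e4m3fn)
torch.ops._C.static_scaled_fp8_quant(q_unfused, y, scale)

# Fused: one kernel launch, the pattern the FusionPass rewrites to.
q_fused = torch.empty_like(x, dtype=torch.float8_e4m3fn)
torch.ops._C.rms_norm_static_fp8_quant(q_fused, x, weight, scale, eps)

The `fused_add_rms_norm_static_fp8_quant` op plays the same role for the residual-add variant of RMSNorm.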
@@ -191,6 +191,7 @@ set(VLLM_EXT_SRC
   "csrc/pos_encoding_kernels.cu"
   "csrc/activation_kernels.cu"
   "csrc/layernorm_kernels.cu"
+  "csrc/layernorm_quant_kernels.cu"
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
@@ -1,21 +1,13 @@
-#include <torch/all.h>
-#include <ATen/cuda/CUDAContext.h>
+#include "type_convert.cuh"
+#include "dispatch_utils.h"
+
+#include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>

-#include "dispatch_utils.h"
 #ifndef USE_ROCM
-  #include <cuda_bf16.h>
-  #include <cuda_fp16.h>
-  #include <cub/util_type.cuh>
   #include <cub/cub.cuh>
 #else
-  #include <hip/hip_bf16.h>
-  #include <hip/hip_fp16.h>
-  #include <hipcub/util_type.hpp>
   #include <hipcub/hipcub.hpp>
-
-using __nv_bfloat16 = __hip_bfloat16;
-using __nv_bfloat162 = __hip_bfloat162;
 #endif

 namespace vllm {
@@ -51,155 +43,6 @@ __global__ void rms_norm_kernel(
   }
 }

-[removed: the _typeConvert converter structs and the _f16Vec vector struct,
- together with their comments; this commit moves the block unchanged into the
- new header csrc/type_convert.cuh, shown in full later in this diff]

 /* Function specialization in the case of FP16/BF16 tensors.
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
csrc/layernorm_quant_kernels.cu (new file, 234 lines)
@@ -0,0 +1,234 @@
/*
 * This file contains the CUDA kernels for the fused quantized layernorm.
 * The kernels correspond to the kernels in layernorm_kernels.cu, except they
 * also produce quantized output directly.
 * Currently, only static fp8 quantization is supported.
 */

#include "type_convert.cuh"
#include "quantization/fp8/common.cuh"
#include "dispatch_utils.h"

#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif

namespace vllm {

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
__global__ void rms_norm_static_fp8_quant_kernel(
    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
    const scalar_t* __restrict__ input,   // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float)input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)input[blockIdx.x * hidden_size + idx];
    float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
    out[blockIdx.x * hidden_size + idx] =
        scaled_fp8_conversion<true>(out_norm, scale_inv);
  }
}

/* Function specialization in the case of FP16/BF16 tensors.
   Additional optimizations we can make in this case are
   packed and vectorized operations, which help with the
   memory latency bottleneck. */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  // Sanity checks on our vector struct and type-punned pointer arithmetic
  static_assert(std::is_pod_v<_f16Vec<scalar_t, width>>);
  static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);

  const int vec_hidden_size = hidden_size / width;
  __shared__ float s_variance;
  float variance = 0.0f;
  /* These and the argument pointers are all declared `restrict` as they are
     not aliased in practice. Argument pointers should not be dereferenced
     in this kernel as that would be undefined behavior */
  auto* __restrict__ input_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(input);
  auto* __restrict__ residual_v =
      reinterpret_cast<_f16Vec<scalar_t, width>*>(residual);
  auto* __restrict__ weight_v =
      reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = input_v[id];
    temp += residual_v[id];
    variance += temp.sum_squares();
    residual_v[id] = temp;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    int id = blockIdx.x * vec_hidden_size + idx;
    _f16Vec<scalar_t, width> temp = residual_v[id];
    temp *= s_variance;
    temp *= weight_v[idx];
#pragma unroll
    for (int i = 0; i < width; ++i) {
      out[id * width + i] =
          scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
    }
  }
}

/* Generic fused_add_rms_norm_kernel
   The width field is not used here but necessary for other specializations.
 */
template <typename scalar_t, int width>
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
fused_add_rms_norm_static_fp8_quant_kernel(
    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
    scalar_t* __restrict__ input,         // [..., hidden_size]
    scalar_t* __restrict__ residual,      // [..., hidden_size]
    const scalar_t* __restrict__ weight,  // [hidden_size]
    const float* __restrict__ scale,      // [1]
    const float epsilon, const int num_tokens, const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    scalar_t z = input[blockIdx.x * hidden_size + idx];
    z += residual[blockIdx.x * hidden_size + idx];
    float x = (float)z;
    variance += x * x;
    residual[blockIdx.x * hidden_size + idx] = z;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);

  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  // invert scale to avoid division
  float const scale_inv = 1.0f / *scale;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float)residual[blockIdx.x * hidden_size + idx];
    float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
    out[blockIdx.x * hidden_size + idx] =
        scaled_fp8_conversion<true>(out_norm, scale_inv);
  }
}

}  // namespace vllm

void rms_norm_static_fp8_quant(torch::Tensor& out,     // [..., hidden_size]
                               torch::Tensor& input,   // [..., hidden_size]
                               torch::Tensor& weight,  // [hidden_size]
                               torch::Tensor& scale,   // [1]
                               double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
    vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
        <<<grid, block, 0, stream>>>(
            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
            num_tokens, hidden_size);
  });
}

#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                    \
  VLLM_DISPATCH_FLOATING_TYPES(                                             \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {               \
        vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width>   \
            <<<grid, block, 0, stream>>>(                                   \
                out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),       \
                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
                scale.data_ptr<float>(), epsilon, num_tokens, hidden_size); \
      });

void fused_add_rms_norm_static_fp8_quant(
    torch::Tensor& out,       // [..., hidden_size]
    torch::Tensor& input,     // [..., hidden_size]
    torch::Tensor& residual,  // [..., hidden_size]
    torch::Tensor& weight,    // [hidden_size]
    torch::Tensor& scale,     // [1]
    double epsilon) {
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  /* This kernel is memory-latency bound in many scenarios.
     When num_tokens is large, a smaller block size allows
     for increased block occupancy on CUs and better latency
     hiding on global mem ops. */
  const int max_block_size = (num_tokens < 256) ? 1024 : 256;
  dim3 block(std::min(hidden_size, max_block_size));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  /* If the tensor types are FP16/BF16, try to use the optimized kernel
     with packed + vectorized ops.
     Max optimization is achieved with a width-8 vector of FP16/BF16s
     since we can load at most 128 bits at once in a global memory op.
     However, this requires each tensor's data to be aligned to 16
     bytes.
   */
  auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
  auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
  auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
  bool ptrs_are_aligned =
      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
  if (ptrs_are_aligned && hidden_size % 8 == 0) {
    LAUNCH_FUSED_ADD_RMS_NORM(8);
  } else {
    LAUNCH_FUSED_ADD_RMS_NORM(0);
  }
}
csrc/ops.h (+10 lines)
@@ -56,6 +56,16 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);

+void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
+                               torch::Tensor& weight, torch::Tensor& scale,
+                               double epsilon);
+
+void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
+                                         torch::Tensor& input,
+                                         torch::Tensor& residual,
+                                         torch::Tensor& weight,
+                                         torch::Tensor& scale, double epsilon);
+
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
@@ -1,185 +1,16 @@
-#include <ATen/cuda/CUDAContext.h>
-#include <torch/all.h>
-#include <c10/cuda/CUDAGuard.h>
-
-#include <cmath>
-
-#include "cuda_compat.h"
+#include "common.cuh"
 #include "dispatch_utils.h"

+#include <c10/cuda/CUDAGuard.h>
+
 #ifndef USE_ROCM
-  #include <cub/util_type.cuh>
   #include <cub/cub.cuh>
 #else
-  #include <hipcub/util_type.hpp>
   #include <hipcub/hipcub.hpp>
 #endif

-#ifndef USE_ROCM
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
-    std::numeric_limits<FP8_TYPE>::max();
-#else
-  #include "amd/hip_float8.h"
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-// Using the default max value from pytorch (240.0) will cause accuracy
-// issue when running dynamic quantization. Here use 224.0f for rocm.
-constexpr auto FP8_E4M3_MAX = 224.0f;
-#endif
-
 namespace vllm {

-[removed: atomicMaxFloat, scaled_fp8_conversion, segmented_max_reduction,
- vec4_t, float8x4_t, thread_max_vec, and scaled_fp8_conversion_vec; this
- commit moves the block unchanged into the new header
- csrc/quantization/fp8/common.cuh, shown in full below]

 template <typename scalar_t>
 __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                         const scalar_t* __restrict__ input,
csrc/quantization/fp8/common.cuh (new file, 172 lines)
@@ -0,0 +1,172 @@
#pragma once

#include <cmath>

#ifndef USE_ROCM
  #include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
#else
  #include <c10/util/Float8_e4m3fnuz.h>
  #include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(r);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store maximum for all values processes by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

typedef struct __align__(4) {
  FP8_TYPE x;
  FP8_TYPE y;
  FP8_TYPE z;
  FP8_TYPE w;
} float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_vec_elems = num_elems >> 2;
  float absmax_val = 0.0f;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    absmax_val = max(absmax_val, fabs(in_vec.x));
    absmax_val = max(absmax_val, fabs(in_vec.y));
    absmax_val = max(absmax_val, fabs(in_vec.z));
    absmax_val = max(absmax_val, fabs(in_vec.w));
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = max(absmax_val, fabs(input[i]));
  }

  return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[i]), scale);
  }
}

}  // namespace vllm
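The `is_scale_inverted` switch above only chooses between dividing by the scale and multiplying by a precomputed reciprocal; the result is then clamped to ±FP8_E4M3_MAX before the cast. A small hedged reference sketch of that conversion in PyTorch terms, assuming a CUDA build where `FP8_TYPE` is `float8_e4m3fn`; the function name and inputs here are illustrative:

import torch

def scaled_fp8_convert_ref(val: torch.Tensor, scale: float,
                           scale_inverted: bool = False) -> torch.Tensor:
    # Mirror of scaled_fp8_conversion<is_scale_inverted>: scale (or multiply
    # by the inverse), clamp to the representable fp8 e4m3 range, then cast.
    x = val * scale if scale_inverted else val / scale
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    return x.clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)

# Example: quantize with a static scale of 0.01.
out = scaled_fp8_convert_ref(torch.randn(8), scale=0.01)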
@@ -101,7 +101,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Layernorm
   // Apply Root Mean Square (RMS) Normalization to the input tensor.
   ops.def(
-      "rms_norm(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
+      "rms_norm(Tensor! result, Tensor input, Tensor weight, float epsilon) -> "
       "()");
   ops.impl("rms_norm", torch::kCUDA, &rms_norm);

@@ -111,6 +111,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "float epsilon) -> ()");
   ops.impl("fused_add_rms_norm", torch::kCUDA, &fused_add_rms_norm);

+  // Layernorm-quant
+  // Apply Root Mean Square (RMS) Normalization to the input tensor.
+  ops.def(
+      "rms_norm_static_fp8_quant(Tensor! result, Tensor input, Tensor weight, "
+      "Tensor scale, float epsilon) -> "
+      "()");
+  ops.impl("rms_norm_static_fp8_quant", torch::kCUDA,
+           &rms_norm_static_fp8_quant);
+
+  // In-place fused Add and RMS Normalization.
+  ops.def(
+      "fused_add_rms_norm_static_fp8_quant(Tensor! result, Tensor input, "
+      "Tensor! residual, Tensor weight, "
+      "Tensor scale, float epsilon) -> ()");
+  ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
+           &fused_add_rms_norm_static_fp8_quant);
+
   // Rotary embedding
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(

@@ -322,18 +339,20 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

   // Compute FP8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()");
+      "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "
+      "()");
   ops.impl("static_scaled_fp8_quant", torch::kCUDA, &static_scaled_fp8_quant);

   // Compute dynamic-per-tensor FP8 quantized tensor and scaling factor.
   ops.def(
-      "dynamic_scaled_fp8_quant(Tensor! out, Tensor input, Tensor! scale) -> "
+      "dynamic_scaled_fp8_quant(Tensor! result, Tensor input, Tensor! scale) "
+      "-> "
       "()");
   ops.impl("dynamic_scaled_fp8_quant", torch::kCUDA, &dynamic_scaled_fp8_quant);

   // Compute dynamic-per-token FP8 quantized tensor and scaling factor.
   ops.def(
-      "dynamic_per_token_scaled_fp8_quant(Tensor! out, Tensor input, "
+      "dynamic_per_token_scaled_fp8_quant(Tensor! result, Tensor input, "
       "Tensor! scale, Tensor? scale_ub) -> "
       "()");
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,

@@ -341,13 +360,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
-      "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
+      "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
       "Tensor? azp) -> ()");
   ops.impl("static_scaled_int8_quant", torch::kCUDA, &static_scaled_int8_quant);

   // Compute int8 quantized tensor and scaling factor
   ops.def(
-      "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
+      "dynamic_scaled_int8_quant(Tensor! result, Tensor input, Tensor! scale, "
       "Tensor!? azp) -> ()");
   ops.impl("dynamic_scaled_int8_quant", torch::kCUDA,
            &dynamic_scaled_int8_quant);
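These schemas consistently rename the output argument from `out` to `result`. The compilation passes then look the output up by keyword (`kwargs['result']` in the updated `fix_functionalization` below), and the comment removed in that hunk notes that an argument literally named `out` interfered with Inductor lowering. A hedged sketch of the keyword-argument form from Python, with illustrative tensors and assuming vLLM's extension has registered the op:

import torch

x = torch.randn(4, 64, dtype=torch.float16, device="cuda")
weight = torch.ones(64, dtype=torch.float16, device="cuda")
result = torch.empty_like(x)

# With the schema's output argument named `result`, traced graphs (and manual
# calls) can pass it by keyword as well as positionally.
torch.ops._C.rms_norm(result=result, input=x, weight=weight, epsilon=1e-6)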
csrc/type_convert.cuh (new file, 165 lines)
@@ -0,0 +1,165 @@
#pragma once

#include <torch/all.h>

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
#endif

namespace vllm {
/* Converter structs for the conversion from torch types to HIP/CUDA types,
   and the associated type conversions within HIP/CUDA. These helpers need
   to be implemented for now because the relevant type conversion
   operators/constructors are not consistently implemented by HIP/CUDA, so
   a generic conversion via type casts cannot be implemented.

   Each struct should have the member static constexpr bool `exists`:
   If false, the optimized kernel is not used for the corresponding torch type.
   If true, the struct should be fully defined as shown in the examples below.
 */
template <typename torch_type>
struct _typeConvert {
  static constexpr bool exists = false;
};

#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template <>
struct _typeConvert<c10::Half> {
  static constexpr bool exists = true;
  using hip_type = __half;
  using packed_hip_type = __half2;

  __device__ static inline float convert(hip_type x) { return __half2float(x); }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __half22float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2half_rn(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22half2_rn(x);
  }
};

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template <>
struct _typeConvert<c10::BFloat16> {
  static constexpr bool exists = true;
  using hip_type = __nv_bfloat16;
  using packed_hip_type = __nv_bfloat162;

  __device__ static inline float convert(hip_type x) {
    return __bfloat162float(x);
  }
  __device__ static inline float2 convert(packed_hip_type x) {
    return __bfloat1622float2(x);
  }
  __device__ static inline hip_type convert(float x) {
    return __float2bfloat16(x);
  }
  __device__ static inline packed_hip_type convert(float2 x) {
    return __float22bfloat162_rn(x);
  }
};
#endif  // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif  // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))

/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
   for appropriate specializations of fused_add_rms_norm_kernel.
   Only functions that are necessary in that kernel are implemented.
   Alignment to 16 bytes is required to use 128-bit global memory ops.
 */
template <typename scalar_t, int width>
struct alignas(16) _f16Vec {
  /* Not theoretically necessary that width is a power of 2 but should
     almost always be the case for optimization purposes */
  static_assert(width > 0 && (width & (width - 1)) == 0,
                "Width is not a positive power of 2!");
  using Converter = _typeConvert<scalar_t>;
  using T1 = typename Converter::hip_type;
  using T2 = typename Converter::packed_hip_type;
  T1 data[width];

  __device__ _f16Vec& operator+=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp += T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] += other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const _f16Vec<scalar_t, width>& other) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        T2 temp{data[i], data[i + 1]};
        temp *= T2{other.data[i], other.data[i + 1]};
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) data[i] *= other.data[i];
    }
    return *this;
  }

  __device__ _f16Vec& operator*=(const float scale) {
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 temp_f = Converter::convert(T2{data[i], data[i + 1]});
        temp_f.x *= scale;
        temp_f.y *= scale;
        T2 temp = Converter::convert(temp_f);
        data[i] = temp.x;
        data[i + 1] = temp.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float temp = Converter::convert(data[i]) * scale;
        data[i] = Converter::convert(temp);
      }
    }
    return *this;
  }

  __device__ float sum_squares() const {
    float result = 0.0f;
    if constexpr (width % 2 == 0) {
#pragma unroll
      for (int i = 0; i < width; i += 2) {
        float2 z = Converter::convert(T2{data[i], data[i + 1]});
        result += z.x * z.x + z.y * z.y;
      }
    } else {
#pragma unroll
      for (int i = 0; i < width; ++i) {
        float x = Converter::convert(data[i]);
        result += x * x;
      }
    }
    return result;
  }
};
}  // namespace vllm
tests/compile/backend.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from copy import deepcopy
from typing import Callable

import torch


class TestBackend:
    """
    This class provides a simple Inductor backend that can be used for testing.
    It takes a list of custom passes and runs them after Inductor's passes.
    It also saves the graph before and after the custom passes for inspection.
    """

    def __init__(self, *args: Callable[[torch.fx.Graph], None]):
        self.custom_passes = args
        from torch._inductor import config
        self.current_config = config.shallow_copy_dict()
        self.current_config['post_grad_custom_post_pass'] = self.post_pass

    def __call__(self, graph: torch.fx.GraphModule, example_inputs):
        from torch._inductor.compile_fx import compile_fx
        return compile_fx(graph,
                          example_inputs,
                          config_patches=self.current_config)

    def post_pass(self, graph: torch.fx.Graph):
        self.graph_pre_pass = deepcopy(graph)
        for pass_ in self.custom_passes:
            pass_(graph)

        self.graph_post_pass = deepcopy(graph)
        # assign by reference, will reflect the final state of the graph
        self.final_graph = graph
tests/compile/test_fusion.py (new file, 92 lines)
@@ -0,0 +1,92 @@
import pytest
import torch
from compressed_tensors.quantization import FP8_DTYPE

import vllm.envs as envs
from vllm.compilation.config import CompilationConfig
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
                                     find_auto_fn_maybe)
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    apply_fp8_linear)

from .backend import TestBackend


class TestModel(torch.nn.Module):

    def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
        self.w = [
            torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
            for _ in range(2)
        ]

    def forward(self, x):
        resid = torch.relu(x)
        y = self.norm[0](x)

        x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
        # make sure resid is used for replacement to work
        y2, resid = self.norm[1](x2, resid)

        x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
        y3, resid = self.norm[2](x3, resid)  # use resid here
        return y3


# Init does pattern registration, which can only happen once
config = CompilationConfig(enable_fusion=True)
reshape_pass = RedundantReshapesPass(config)
fusion_pass = FusionPass.instance(config)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                    reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

    if eps != 1e-5:
        pytest.skip("Only test eps=1e-5 for now")

    # Reshape pass is needed for the fusion pass to work
    backend = TestBackend(reshape_pass, fusion_pass)
    model = TestModel(hidden_size, eps)

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size)
    torch._dynamo.mark_dynamic(x, 0)

    result = model(x)

    model2 = torch.compile(model, backend=backend)
    result2 = model2(x)

    # Check that it gives the same answer
    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)

    # Check substitution worked
    pre_nodes = backend.graph_pre_pass.nodes
    post_nodes = backend.graph_post_pass.nodes

    rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
    add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default

    # In pre-nodes, fp8 quant should be present and fused kernels should not
    assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
    assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
    find_auto_fn(pre_nodes, fp8_quant)

    # In post-nodes, fused kernels should be present and fp8 quant should not
    find_auto_fn(post_nodes, rms_quant)
    find_auto_fn(post_nodes, add_rms_quant)
    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
@@ -1,13 +1,14 @@
 import pytest
 import torch

+from tests.kernels.quant_utils import FP8_DTYPE
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.platforms import current_platform

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
-HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
+HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
                 8199]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]

@@ -59,3 +60,75 @@ def test_rms_norm(
     else:
         opcheck(torch.ops._C.rms_norm,
                 (out, x, layer.weight.data, layer.variance_epsilon))
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("quant_scale", [1.0, 0.01, 10.0])
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+def test_fused_rms_norm_quant(
+    num_tokens: int,
+    hidden_size: int,
+    add_residual: bool,
+    dtype: torch.dtype,
+    quant_scale: float,
+    seed: int,
+    device: str,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    weight = torch.empty(hidden_size, dtype=dtype).normal_(mean=1.0, std=0.1)
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
+    x *= scale
+    if add_residual:
+        residual = torch.randn_like(x) * scale
+        residual_fused = residual.clone()
+    else:
+        residual = residual_fused = None
+
+    out_norm = torch.empty_like(x)
+    out_quant = torch.empty_like(x, dtype=FP8_DTYPE)
+    out_quant_fused = torch.empty_like(out_quant)
+
+    quant_scale_t = torch.tensor(quant_scale, dtype=torch.float32)
+
+    if add_residual:
+        torch.ops._C.fused_add_rms_norm_static_fp8_quant(
+            out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6)
+
+        # Unfused kernel is in-place so it goes second
+        # Also use a separate clone of x to avoid modifying the input
+        x_unfused = x.clone()
+        torch.ops._C.fused_add_rms_norm(x_unfused, residual, weight, 1e-6)
+        torch.ops._C.static_scaled_fp8_quant(out_quant, x_unfused,
+                                             quant_scale_t)
+
+        torch.cuda.synchronize()
+        torch.testing.assert_close(residual_fused,
+                                   residual,
+                                   atol=1e-2,
+                                   rtol=1e-2)
+
+        opcheck(
+            torch.ops._C.fused_add_rms_norm_static_fp8_quant,
+            (out_quant_fused, x, residual_fused, weight, quant_scale_t, 1e-6))
+    else:
+        torch.ops._C.rms_norm_static_fp8_quant(out_quant_fused, x, weight,
+                                               quant_scale_t, 1e-6)
+
+        torch.ops._C.rms_norm(out_norm, x, weight, 1e-6)
+        torch.ops._C.static_scaled_fp8_quant(out_quant, out_norm,
+                                             quant_scale_t)
+
+        opcheck(torch.ops._C.rms_norm_static_fp8_quant,
+                (out_quant_fused, x, weight, quant_scale_t, 1e-6))
+
+    torch.testing.assert_close(out_quant_fused.to(dtype=torch.float32),
+                               out_quant.to(dtype=torch.float32),
+                               atol=1e-3,
+                               rtol=1e-3)
vllm/compilation/backends.py
@@ -2,7 +2,8 @@ import copy
 import dataclasses
 import operator
 from contextlib import ExitStack
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Tuple,
+                    Union)
 from unittest.mock import patch

 import torch
@@ -10,11 +11,13 @@ import torch.fx as fx

 import vllm.envs as envs
 from vllm.logger import init_logger
-from vllm.utils import weak_ref_tensors
+from vllm.utils import combine_fx_passes, weak_ref_tensors

 from .config import CompilationConfig
 from .counter import compilation_counter
+from .fusion import FusionPass
 from .levels import CompilationLevel
+from .reshapes import RedundantReshapesPass

 logger = init_logger(__name__)

@@ -99,28 +102,74 @@ def fix_functionalization(graph: fx.Graph):
                         user.replace_all_uses_with(replace_node)
                         nodes_to_remove.append(user)
                 nodes_to_remove.append(node)
+            elif (node.args[0] ==
+                  torch.ops._C.fused_add_rms_norm_static_fp8_quant.default):
+                # manual replace for fused_add_rms_norm_static_fp8_quant
+                # this is the most effective optimization for llama
+                # failing to do this will result in many unnecessary copies
+
+                kwargs = node.kwargs
+
+                result = kwargs['result']
+                residual = kwargs['residual']
+
+                # Create a new call to
+                # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.fused_add_rms_norm_static_fp8_quant.
+                        default,
+                        kwargs=kwargs)
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        # Remove the getitem node
+                        if user.args[1] == 1:
+                            replace_node = result
+                        elif user.args[1] == 2:
+                            replace_node = residual
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)

             elif node.args[0] == torch.ops._C.rms_norm.default:
                 # manual replace for rms_norm

                 kwargs = node.kwargs

-                input = kwargs['input']
-                out = kwargs['out']
-                weight = kwargs['weight']
-                epsilon = kwargs['epsilon']
-                # Create a new call to torch.ops._C.rotary_embedding.default
-                # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
+                replace_node = kwargs['result']
+                # Create a new call to torch.ops._C.rms_norm.default
                 with graph.inserting_before(node):
                     # just insert the call to the custom op
                     # NOTE: don't run dead code elimination,
                     # otherwise this op will be removed
-                    graph.call_function(
-                        torch.ops._C.rms_norm.default,
-                        args=(out, input, weight, epsilon),
-                    )
-
-                replace_node = out
+                    graph.call_function(torch.ops._C.rms_norm.default,
+                                        kwargs=kwargs)
+
+                for user in list(node.users):
+                    if user.op == 'call_function' and user.target == operator.getitem:  # noqa
+                        user.replace_all_uses_with(replace_node)
+                        nodes_to_remove.append(user)
+                nodes_to_remove.append(node)
+
+            elif node.args[
+                    0] == torch.ops._C.rms_norm_static_fp8_quant.default:  # noqa
+                # manual replace for rms_norm_static_fp8_quant
+
+                kwargs = node.kwargs
+
+                replace_node = kwargs['result']
+                # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default  # noqa
+                with graph.inserting_before(node):
+                    # just insert the call to the custom op
+                    # NOTE: don't run dead code elimination,
+                    # otherwise this op will be removed
+                    graph.call_function(
+                        torch.ops._C.rms_norm_static_fp8_quant.default,
+                        kwargs=kwargs)

                 for user in list(node.users):
                     if user.op == 'call_function' and user.target == operator.getitem:  # noqa
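Schematically, the rewrite above turns the functionalized wrapper back into a direct in-place call and rebinds its getitem users to the mutated buffers. The node names in this sketch are illustrative, not actual fx output.

# Before (functionalized form produced by auto_functionalization):
#   at = auto_functionalized(
#       torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
#       result=result, input=x, residual=residual, weight=w,
#       scale=s, epsilon=eps)
#   out      = operator.getitem(at, 1)   # quantized result
#   residual = operator.getitem(at, 2)   # updated residual
#
# After fix_functionalization:
#   torch.ops._C.fused_add_rms_norm_static_fp8_quant.default(
#       result=result, input=x, residual=residual, weight=w,
#       scale=s, epsilon=eps)
#   # getitem users are rebound directly to `result` / `residual`,
#   # avoiding the extra copies the functionalized form would imply.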
@@ -136,7 +185,7 @@ def fix_functionalization(graph: fx.Graph):
                 input = kwargs['input']
                 out = kwargs['out']

-                # Create a new call to torch.ops._C.rotary_embedding.default
+                # Create a new call to torch.ops._C.silu_and_mul.default
                 # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
                 with graph.inserting_before(node):
                     # just insert the call to the custom op
@@ -319,6 +368,13 @@ class VllmBackend:

     The major work of this backend is to split the graph into
     piecewise graphs, and pass them to the piecewise backend.
+
+    This backend also handles custom passes and adds them to Inductor config.
+    The order of the post-grad post-passes is:
+    1. post_grad_passes (constructor parameter)
+    2. config["post_grad_custom_post_pass"]
+    3. fix_functionalization
+    This way, all passes operate on a functionalized graph.
     """

     compilation_configs: CompilationConfig
@@ -330,8 +386,10 @@ class VllmBackend:
     split_gm: fx.GraphModule
     piecewise_graphs: List[SplitItem]
     returned_callable: Callable
+    # Inductor passes to run on the graph pre-defunctionalization
+    post_grad_passes: Sequence[Callable]

-    def __init__(self, ):
+    def __init__(self, post_grad_passes: Sequence[Callable] = ()):
         global global_graph_pool
         if global_graph_pool is None:
             global_graph_pool = torch.cuda.graph_pool_handle()
@@ -340,10 +398,30 @@ class VllmBackend:
         # streams, it might not be safe to share a global pool.
         # only investigate this when we use multiple streams
         self.graph_pool = global_graph_pool
+        self.post_grad_passes = post_grad_passes

         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here

+    def add_passes_to_config(self):
+        config = self.compilation_configs
+        passes = list(self.post_grad_passes)
+
+        passes = passes + [RedundantReshapesPass(config)]
+
+        if config.enable_fusion:
+            passes = passes + [FusionPass.instance(config)]
+
+        inductor_config = config.inductor_compile_config
+        if "post_grad_custom_post_pass" in inductor_config:
+            passes = passes + [inductor_config["post_grad_custom_post_pass"]]
+
+        # add the fix_functionalization pass last, so that all other
+        # passes operate on a functionalized graph
+        passes = passes + [fix_functionalization]
+        combined_pass = combine_fx_passes(passes)
+        inductor_config["post_grad_custom_post_pass"] = combined_pass
+
     def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:

         compilation_counter.num_graphs_seen += 1
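combine_fx_passes (imported from vllm.utils in the hunk above) is what lets all of these run as a single Inductor post_grad_custom_post_pass hook; conceptually it just chains the callables. A sketch of that idea, under the assumption that each pass mutates the fx.Graph in place (chain_graph_passes is an illustrative stand-in, not the real vllm.utils implementation):

from typing import Callable, List

import torch.fx


def chain_graph_passes(passes: List[Callable]) -> Callable:
    # Apply each graph pass in order on the same fx.Graph.
    def combined(graph: torch.fx.Graph) -> None:
        for p in passes:
            p(graph)

    return combined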
@@ -357,6 +435,7 @@ class VllmBackend:
         # we get the sizes to capture for cudagraph
         # from compilation context
         self.compilation_configs = CompilationConfig.select_and_init_config()
+        self.add_passes_to_config()

         self.split_gm, self.piecewise_graphs = split_graph(
             graph, self.compilation_configs.non_cudagraph_ops)
vllm/compilation/config.py
@@ -1,4 +1,5 @@
 import copy
+from pathlib import Path
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Field, PrivateAttr
@@ -50,6 +51,12 @@ class CompilationConfig(BaseModel):
             name because the config uses json format. If we pass the config
             from Python, functions can also be passed directly via Python object
             constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
+    - Custom inductor passes:
+        - dump_graph_stages: list of stages for which we want to dump the graph.
+            Each pass defines its own stages (before, after, maybe in-between).
+        - dump_graph_dir: directory to dump the graph. Default is .
+        - enable_fusion: whether to enable the custom fusion pass.
+            TODO better pass enabling system.

     Why we have different sizes for cudagraph and inductor:
     - cudagraph: a cudagraph captured for a specific size can only be used
@@ -72,6 +79,10 @@ class CompilationConfig(BaseModel):
     cudagraph_num_of_warmups: int = 0
     cudagraph_capture_sizes: Optional[List[int]] = None

+    dump_graph_stages: List[str] = Field(default_factory=list)
+    dump_graph_dir: Path = Field(default=Path("."))
+    enable_fusion: bool = True
+
     # not configurable, computed after init
     compile_sizes: List[int] = PrivateAttr
     capture_sizes: List[int] = PrivateAttr
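A hedged usage sketch for the new knobs. The field names come from the hunk above; the stage names mirror the dump_graph calls in FusionPass further below, and the string path should be coerced to pathlib.Path by pydantic. Treat the exact construction as illustrative.

from vllm.compilation.config import CompilationConfig

config = CompilationConfig(
    enable_fusion=True,
    dump_graph_stages=["before_fusion", "after_fusion"],
    dump_graph_dir="compile_graphs",  # coerced to Path by pydantic
)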
@@ -81,7 +92,7 @@ class CompilationConfig(BaseModel):
             if not isinstance(v, str):
                 assert callable(v), (
                     f"pass {k} should be a function or a qualified name")
-                self.inductor_passes[k] = v
+                self.inductor_compile_config[k] = v
                 continue

             # resolve function from qualified name
@@ -91,18 +102,6 @@ class CompilationConfig(BaseModel):
             func = __import__(module).__dict__[func_name]
             self.inductor_compile_config[k] = func

-        from vllm.compilation.backends import fix_functionalization
-        from vllm.utils import combine_fx_passes
-        if "post_grad_custom_post_pass" in self.inductor_compile_config:
-            self.inductor_compile_config[
-                "post_grad_custom_post_pass"] = combine_fx_passes(
-                    fix_functionalization,
-                    self.inductor_compile_config["post_grad_custom_post_pass"],
-                )
-        else:
-            self.inductor_compile_config[
-                "post_grad_custom_post_pass"] = fix_functionalization
-
     def init_during_runtime(self):
         """To complete the initialization of config,
         we need to know the compile context, which is only available
vllm/compilation/fusion.py (new file, 291 lines)
@@ -0,0 +1,291 @@
import operator
from typing import Iterable, List, Optional

import torch
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
                                             fwd_only, register_replacement)

from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger

logger = init_logger(__name__)


def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
                       input: torch.Tensor, weight: torch.Tensor,
                       scale: torch.Tensor):
    at1 = auto_functionalized(torch.ops._C.rms_norm.default,
                              result=result_rms,
                              input=input,
                              weight=weight,
                              epsilon=1e-5)
    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
                              result=result,
                              input=at1[1],
                              scale=scale)

    # result
    return at2[1]


def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor,
                           input: torch.Tensor, weight: torch.Tensor,
                           scale: torch.Tensor):
    at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default,
                             result=result,
                             input=input,
                             weight=weight,
                             scale=scale,
                             epsilon=1e-5)

    # result
    return at[1]


def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor,
                                residual: torch.Tensor, weight: torch.Tensor,
                                scale: torch.Tensor):
    at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default,
                             input=input,
                             residual=residual,
                             weight=weight,
                             epsilon=1e-5)
    at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
                              result=result,
                              input=at[1],
                              scale=scale)

    # result, residual
    return at1[1], at[2]


def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor,
                                    residual: torch.Tensor,
                                    weight: torch.Tensor, scale: torch.Tensor):
    at = auto_functionalized(
        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
        result=result,
        input=input,
        residual=residual,
        weight=weight,
        scale=scale,
        epsilon=1e-5)
    # result, residual
    return at[1], at[2]


def empty_bf16(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")


def empty_fp8(*args, **kwargs):
    fp8 = torch.float8_e4m3fn
    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")


def empty_fp32(*args, **kwargs):
    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")


# Utilities for post-processing multi-output matches
def is_func(node: torch.fx.Node, target) -> bool:
    return node.op == "call_function" and node.target == target


# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
                       op) -> Optional[torch.fx.Node]:
    for node in nodes:
        if is_func(node, auto_functionalized) and node.args[0] == op:  # noqa
            return node
    return None


# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
    node = find_auto_fn_maybe(nodes, op)
    assert node is not None, f"Could not find {op} in nodes {nodes}"
    return node


# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: torch.fx.Node,
                       idx: int) -> Optional[torch.fx.Node]:
    for user in node.users:
        if is_func(user, operator.getitem) and user.args[1] == idx:
            return user
    return None


# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
    ret = find_getitem_maybe(node, idx)
    assert ret is not None, f"Could not find getitem {idx} in node {node}"
    return ret


class FusionPass(InductorPass):
    """
    This pass fuses a pre-defined set of custom ops into fused ops.
    It uses the torch pattern matcher to find the patterns and replace them.
    It also manually processes multi-output matches, as those are broken in
    the torch pattern matcher.

    Because patterns can only be registered once, the pass is a singleton.
    This will be addressed in a future version of PyTorch:
    https://github.com/pytorch/pytorch/pull/139321#issuecomment-2452354980
    """

    _instance: 'Optional[FusionPass]' = None

    @classmethod
    def instance(cls, config: CompilationConfig):
        """
        Get the singleton instance of the FusionPass.
        If the instance exists, the config is updated but
        initialization is not repeated.
        """
        if cls._instance is None:
            cls._instance = FusionPass(config)
        else:
            cls._instance.config = config
        return cls._instance

    def __init__(self, config: CompilationConfig):
        assert self.__class__._instance is None, \
            "FusionPass singleton instance already exists"
        super().__init__(config)

        self.matches: List[Match] = []
        self.patterns: PatternMatcherPass = PatternMatcherPass(
            pass_name="fusion_pass")

        # Fuse rms_norm + static_scaled_fp8_quant into
        # rms_norm_static_fp8_quant
        inputs = [
            empty_fp8(5, 4),
            empty_bf16(5, 4),
            empty_bf16(5, 4),
            empty_bf16(1, 5),
            empty_fp32(1, 1)
        ]
        register_replacement(rms_pattern_static, rms_replacement_static,
                             inputs, fwd_only, self.patterns)

        # Fuse fused_add_rms_norm + static_scaled_fp8_quant into
        # fused_add_rms_norm_static_fp8_quant
        # Because pattern has 2 outputs, we need to manually process the match
        # (see process_matches)
        inputs = [
            empty_fp8(5, 4),
            empty_bf16(5, 4),
            empty_bf16(5, 4),
            empty_bf16(1, 5),
            empty_fp32(1, 1)
        ]
        register_replacement(rms_pattern_residual_static,
                             rms_replacement_residual_static,
                             inputs,
                             fwd_only,
                             self.patterns,
                             extra_check=lambda m: self.record_match(m))

    def record_match(self, match: Match) -> bool:
        # Hijack the extra_check to record the match and
        # save it for post-processing.
        self.matches.append(match)

        # Return False to prevent automatic replacement.
        return False

    def process_matches(self, graph: torch.fx.Graph):
        """
        Manually process multi-output matches and replace them with fused nodes.
        This is necessary because the automatic replacement for multi-output
        matches is broken: https://github.com/pytorch/pytorch/issues/137280
        """
        for match in self.matches:
            # To avoid use-before-definition errors, insert replacement nodes
            # after the last node in the match.
            # match.nodes is not guaranteed to be sorted.
            # Find the last node in the match.
            for last_node_in_match in reversed(graph.nodes):
                if last_node_in_match in match.nodes:
                    break
            else:
                raise ValueError("No nodes in graph")

            # Insert a new auto_functionalized node for the fused operation,
            # as well as getitem nodes to extract the result and residual.
            # The auto_functionalized node returns a tuple of
            # (None, result, residual) - None is the function return value.
            # The resulting graph looks like this:
            # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...)  # noqa
            # result_node_new = at[1]
            # residual_node_new = at[2]
            with graph.inserting_after(last_node_in_match):
                kwargs = match.kwargs
                kwargs["epsilon"] = 1e-5  # Currently hard-coded in RMSNorm

                fused_node = graph.call_function(
                    auto_functionalized,
                    (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
                     ),
                    kwargs=kwargs)

                graph.inserting_after(fused_node)
                result_node_new = graph.call_function(operator.getitem,
                                                      (fused_node, 1))
                residual_node_new = graph.call_function(
                    operator.getitem, (fused_node, 2))

            # Last part of replacement is rebinding the users of nodes in the
            # match to use the new nodes.

            # Find the nodes in the match that we need to rebind
            rms_node = find_auto_fn(match.nodes,
                                    torch.ops._C.fused_add_rms_norm.default)
            quant_node = find_auto_fn(
                match.nodes, torch.ops._C.static_scaled_fp8_quant.default)

            assert len(rms_node.users) == 2
            assert len(quant_node.users) == 1

            # meta["val"] is used by de-functionalization and has to contain the
            # value of the node (tuple of tensors) that would be returned by the
            # functionalized node during tracing.

            rms_tup = rms_node.meta["val"]
            quant_tup = quant_node.meta["val"]

            # The result of fused_node must be a tuple with the first element
            # None (the function return value) and the remaining elements
            # representing the mutated inputs.
            fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2])
            fused_node.meta["val"] = fused_tup

            # Find the getitem nodes and replace their uses with the new nodes.
            # The old nodes will be removed by DCE at the end of the pass.
            find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
            find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)

        # Finally, remove matched nodes
        graph.eliminate_dead_code()
        assert all(node not in graph.nodes for match in self.matches
                   for node in match.nodes)

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_fusion")

        count = self.patterns.apply(graph)
        logger.info("Replaced %s patterns", count)
        self.dump_graph(graph, "after_pattern_match")

        # Manually process multi-output matches (and run DCE)
        self.process_matches(graph)
        logger.info("Post-processed %s matches", len(self.matches))
        self.dump_graph(graph, "after_fusion")
        self.matches.clear()
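The pass is built entirely on Inductor's pattern matcher (register_replacement plus PatternMatcherPass), used exactly as in the constructor above. A toy, self-contained illustration of that machinery on ordinary ATen ops follows; nothing here is vLLM-specific and the multiply-add pattern is made up for demonstration.

import torch
from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                             register_replacement)


# Pattern: x * 2 + y. Replacement: the same math as a single fused-style call.
# register_replacement traces both functions with the example inputs;
# toy_patterns.apply(graph) would then rewrite any matching fx graph.
def toy_pattern(x: torch.Tensor, y: torch.Tensor):
    return x * 2 + y


def toy_replacement(x: torch.Tensor, y: torch.Tensor):
    return torch.add(y, x, alpha=2)


toy_patterns = PatternMatcherPass(pass_name="toy_fusion")
register_replacement(toy_pattern, toy_replacement,
                     [torch.empty(5, 4), torch.empty(5, 4)], fwd_only,
                     toy_patterns)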
vllm/compilation/inductor_pass.py (new file, 38 lines)
@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod

import torch

from vllm.compilation.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
    get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger

logger = init_logger(__name__)


class InductorPass(ABC):

    @abstractmethod
    def __call__(self, graph: torch.fx.Graph):
        raise NotImplementedError

    def __init__(self, config: CompilationConfig):
        self.config = config

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        if stage in self.config.dump_graph_stages:
            # Make sure filename includes rank in the distributed setting
            parallel = p_is_init() and get_tp_world_size() > 1
            rank = f"-{get_tp_rank()}" if parallel else ""
            filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"

            logger.info("Printing graph to %s", filepath)
            with open(filepath, "w") as f:
                src = graph.python_code(root_module="self", verbose=True).src
                # Add imports so it's not full of errors
                print("import torch; from torch import device", file=f)
                print(src, file=f)
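A minimal sketch of what a pass built on this base class looks like. The node-counting pass itself is invented for illustration; only the InductorPass interface, dump_graph, and the config fields come from the diff.

import torch.fx

from vllm.compilation.config import CompilationConfig
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger

logger = init_logger(__name__)


class CountCallsPass(InductorPass):
    """Toy pass: log how many call_function nodes the graph contains."""

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_count_calls")
        n = sum(1 for node in graph.nodes if node.op == "call_function")
        logger.info("Graph has %s call_function nodes", n)
        self.dump_graph(graph, "after_count_calls")


# Dumping is opt-in via the config, e.g.:
# CountCallsPass(CompilationConfig(dump_graph_stages=["before_count_calls"]))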
vllm/compilation/reshapes.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from typing import Union

import torch.fx
from torch import SymInt

from vllm.compilation.fusion import is_func
from vllm.compilation.inductor_pass import InductorPass
from vllm.logger import init_logger

logger = init_logger(__name__)


class RedundantReshapesPass(InductorPass):
    """
    This is an inductor pass that removes redundant reshape operations.
    It is required for RMSNorm-quant fusion to work properly.
    That's because apply_fp8_linear adds a reshape, which is redundant
    in the 2D-case.

    Example graph:

    getitem_1: "f16[s0, 4096]" = ...
    view_1: "f16[s0, 4096]" = torch.reshape(getitem_1, [-1, 4096])
    at = auto_functionalized(static_scaled_fp8_quant, input = view_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]

    Can be replaced with:
    getitem_1: "f16[s0, 4096]" = ...
    at = auto_functionalized(static_scaled_fp8_quant, input = getitem_1, ...)
    out: "f8e4m3fn[s0, 4096]" = at[1]
    """

    def __call__(self, graph: torch.fx.Graph):
        self.dump_graph(graph, "before_reshapes")
        count = 0
        # Remove no-op reshapes/views:
        for node in graph.nodes:
            if is_func(node, torch.ops.aten.reshape.default):
                input, shape = node.args[:2]
                input_shape = input.meta["val"].shape
                if len(shape) != len(input_shape):
                    # Reshape changing rank, skip
                    continue

                if shape.count(-1) > 1:
                    # Invalid reshape args, skip
                    continue

                if all(
                        self.dims_equivalent(s, i_s)
                        for s, i_s in zip(shape, input_shape)):
                    node.replace_all_uses_with(input)
                    graph.erase_node(node)
                    count += 1

        logger.info("Removed %s no-op reshapes", count)

        self.dump_graph(graph, "after_reshapes")

    def dims_equivalent(self, dim: Union[int, torch.fx.Node],
                        i_dim: Union[int, SymInt]) -> bool:
        """
        This function checks if two dimensions are equivalent.
        :param dim: The dimension arg to reshape
        :param i_dim: The corresponding dimension in the input tensor
        :return: Are the dimensions equivalent?

        There are three cases in which the dimensions are equivalent:
        1. The dimensions are equal (both integers)
        2. The reshape dimension is -1 (i.e. inferred)
        3. The dimensions both correspond to the same SymInt

        While case 2 does not guarantee the dimensions are equal,
        they are equal if all other dimensions are equal.

        In case 3, the reshape dimension is a torch.fx.Node,
        and its value is a SymInt. That value is equal to the
        input dimension.

        """
        # Case 1 and 2
        if dim == i_dim or dim == -1:
            return True
        # Case 3
        return isinstance(dim, torch.fx.Node) and dim.meta["val"] == i_dim
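For intuition, the kind of reshape this pass removes is a pure no-op at the tensor level; a quick check with plain tensors (the shapes here are chosen arbitrarily for illustration):

import torch

x = torch.randn(7, 4096)          # stands in for the f16[s0, 4096] activation
y = torch.reshape(x, [-1, 4096])  # the redundant view added by apply_fp8_linear

# Same rank, every dimension equivalent (-1 is inferred), same storage:
assert y.shape == x.shape
assert y.data_ptr() == x.data_ptr()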
vllm/envs.py
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
+    VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
     VLLM_CUSTOM_OPS: List[str] = []
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
@@ -226,6 +227,7 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # and disabled when running with Inductor (compile_level >= Inductor).
     "VLLM_CUSTOM_OPS":
     lambda: os.environ.get("VLLM_CUSTOM_OPS", "").replace(" ", "").split(","),

     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
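The TYPE_CHECKING hunk above only declares the new variable's type; the entry that actually reads it is not visible in this (truncated) second hunk. Based on the neighboring entries, it is presumably wired up like the sketch below; treat the key handling as an assumption rather than the commit's exact code.

import os
from typing import Any, Callable, Dict

# Hypothetical wiring that mirrors the VLLM_CUSTOM_OPS entry shown above;
# the actual entry added by the commit may differ.
extra_environment_variables: Dict[str, Callable[[], Any]] = {
    "VLLM_TORCH_COMPILE_CONFIG":
    lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
}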