[torch.compile] Dynamic fp8 + rms_norm fusion (#10906)

Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Authored by Luka Govedič on 2024-12-12 22:19:23 -05:00; committed by GitHub
commit 30870b4f66 (parent 78ed8f57d8)
20 changed files with 1735 additions and 251 deletions

View File

@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
@ -300,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
#
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})

View File

@ -0,0 +1,173 @@
import pickle as pkl
import time
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm
import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
add_residual: bool
dtype: torch.dtype
def description(self):
return (f'N {self.num_tokens} '
f'x D {self.hidden_size} '
f'x R {self.add_residual} '
f'x DT {self.dtype}')
def get_bench_params() -> List[bench_params_t]:
## Test Fixtures
NUM_TOKENS = [2**x for x in range(11)]
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]
combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
bench_params = list(map(lambda x: \
bench_params_t(x[0], x[1], x[2], x[3]), combinations))
return bench_params
# Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
# Quant
torch_out, _, _ = ops.scaled_int8_quant(torch_out)
def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)
# Quant
torch_out, _ = ops.scaled_fp8_quant(torch_out)
def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
out, _ = ops.rms_norm_dynamic_per_token_quant(x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
residual=residual)
# Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
quant_dtype: torch.dtype, label: str, sub_label: str,
fn: Callable, description: str) -> TMeasurement:
min_run_time = 1
globals = {
"rms_norm_layer": rms_norm_layer,
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench(params: bench_params_t, label: str, sub_label: str) \
-> Iterable[TMeasurement]:
# Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / params.hidden_size
x = torch.randn(params.num_tokens,
params.hidden_size,
dtype=params.dtype,
device='cuda') * scale
residual = (torch.randn_like(x) * scale).to(device='cuda') \
if params.add_residual else None
timers = []
# unfused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label,
unfused_int8_impl, "unfused_int8_impl"))
# unfused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
unfused_fp8_impl, "unfused_fp8_impl"))
# fused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
"fused_int8_impl"))
# fused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
fused_impl, "fused_fp8_impl"))
print_timers(timers)
return timers
# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def main():
torch.set_default_device('cuda')
bench_params = get_bench_params()
timers = []
for bp in tqdm(bench_params):
timers.extend(
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers)
# pickle all the results
timestamp = int(time.time())
with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
pkl.dump(timers, f)
if __name__ == '__main__':
main()
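
The results are pickled as a list of torch.utils.benchmark Measurement objects, so a saved run can be re-inspected later. A minimal sketch (the filename below is a placeholder for whatever timestamped .pkl the run wrote):

import pickle as pkl
import torch.utils.benchmark as TBenchmark

# Placeholder path; substitute the rms_norm_dpt_quant-<timestamp>.pkl produced by the run.
with open("rms_norm_dpt_quant-1734000000.pkl", "rb") as f:
    timers = pkl.load(f)

TBenchmark.Compare(timers).print()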

View File

@ -14,6 +14,20 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
// TODO(luka/varun): use FP8_TYPE macro after refactoring
#ifndef USE_ROCM
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \

View File

@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
torch::Tensor& weight,
torch::Tensor& scale, double epsilon);
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales,
double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);

View File

@ -1,6 +1,9 @@
#pragma once
#include "quantization/vectorization.cuh"
#include <cmath>
#include <c10/core/ScalarType.h>
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif
constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
namespace vllm {
@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
}
}
template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};
typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
using float8x4_t = q8x4_t<FP8_TYPE>;
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
int64_t const num_vec_elems = num_elems >> 2;

View File

@ -0,0 +1,160 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../dispatch_utils.h"
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"
namespace vllm {
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void rms_norm_dynamic_per_token_quant_vec(
scalar_out_t* __restrict__ out, // [..., hidden_size]
float* __restrict__ scales, // [num_tokens]
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon,
float const min_scaling_factor, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
float rms = 0.0f;
float token_scale = 0.0f;
// Compute rms
vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual);
// Compute scale
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
hidden_size, residual);
// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
} else {
// FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
}
}
// RMS norm + quant kernel
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__global__ void rms_norm_dynamic_per_token_quant_kernel(
scalar_out_t* __restrict__ out, // [..., hidden_size]
float* __restrict__ scales, // [num_tokens]
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon,
float const min_scaling_factor, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;
if (can_vectorize) {
return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
has_residual>(
out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
hidden_size, residual);
}
float rms = 0.0f;
float token_scale = 0.0f;
// Compute RMS
vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
var_epsilon, residual);
// Compute Scale
vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
hidden_size, residual);
// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
} else {
// FP8 - Do not invert s_token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
}
}
} // namespace vllm
// Residual add + RMS norm + dynamic per token
template <typename scalar_in_t>
void rms_norm_dynamic_per_token_quant_dispatch(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& weight, // [hidden_size]
torch::Tensor& scales, // [num_tokens]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual) {
int32_t hidden_size = input.size(-1);
int32_t num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const float min_scaling_factor =
out.dtype() == torch::kInt8
? std::numeric_limits<float>::epsilon()
: 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);
if (residual.has_value()) {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
true>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, min_scaling_factor, hidden_size,
residual->data_ptr<scalar_in_t>());
});
} else {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
false>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, min_scaling_factor, hidden_size, nullptr);
});
}
}
void rms_norm_dynamic_per_token_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& weight, // [hidden_size]
torch::Tensor& scales, // [num_tokens]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == kFp8Type);
}
TORCH_CHECK(scales.dtype() == torch::kFloat32);
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
out, input, weight, scales, var_epsilon, scale_ub, residual);
});
}
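
For reference, the math this fused kernel implements can be written as a short eager-mode PyTorch sketch. This is an illustration only: it ignores the kernel's intermediate-precision casts, the min_scaling_factor floor, and the optional scale_ub clamp, and in the real op the residual update happens in place.

from typing import Optional, Tuple

import torch


def ref_rms_norm_dynamic_per_token_quant(
        x: torch.Tensor,            # [..., hidden_size]
        weight: torch.Tensor,       # [hidden_size]
        eps: float,
        quant_dtype: torch.dtype,   # torch.int8 or torch.float8_e4m3fn
        residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    if residual is not None:
        x = x + residual            # the kernel also writes this sum back to residual
    f = x.float()
    # RMS norm, then the elementwise weight
    rms = torch.rsqrt(f.pow(2).mean(-1, keepdim=True) + eps)
    y = f * rms * weight.float()
    # One scale per token: absmax / qmax of the quantized dtype
    qmax = (float(torch.iinfo(quant_dtype).max) if quant_dtype == torch.int8
            else float(torch.finfo(quant_dtype).max))
    scales = y.abs().amax(dim=-1, keepdim=True) / qmax
    q = (y / scales).clamp(-qmax, qmax)
    if quant_dtype == torch.int8:
        q = q.round()
    return q.to(quant_dtype), scales, (x if residual is not None else None)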

View File

@ -0,0 +1,327 @@
#pragma once
/**
* __device__ layernorm utilities.
*/
#include "quantization/vectorization.cuh"
#include "quant_conversions.cuh"
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace vllm {
// has_residual must be true, if residual is not a nullptr
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
int32_t const hidden_size, float const epsilon,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// sum of squares
float ss = 0.0f;
for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]);
if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]);
}
ss += x * x;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
__shared__ float s_rms;
if (threadIdx.x == 0) {
s_rms = rsqrtf(ss / hidden_size + epsilon);
}
__syncthreads();
*rms = s_rms;
}
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
float const min_scaling_factor, int32_t const hidden_size,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
float block_absmax_val_maybe = 0.0f;
for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]);
if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]);
}
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
__shared__ float s_token_scale;
if (threadIdx.x == 0) {
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor);
s_token_scale = scale; // Shared memory store
all_token_scales[blockIdx.x] = scale; // Global output store
}
__syncthreads();
*token_scale = s_token_scale;
}
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight,
float const rms, float const scale,
int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
float x = static_cast<float>(input[token_offset + i]);
if constexpr (has_residual) {
x += static_cast<float>(residual[token_offset + i]);
residual[token_offset + i] = static_cast<scalar_t>(x);
}
// Norm
x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
// Quant
output[token_offset + i] =
ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale);
}
}
namespace vectorized {
// Compute 1.0/rms(input)
// hidden_size must be a multiple of 4
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
int32_t const hidden_size, float const epsilon,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
vec4_t<scalar_t> const* vec_residual = nullptr;
if constexpr (has_residual) {
vec_residual =
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
}
// sum of squares
float ss = 0.0f;
int32_t const num_vec_elems = hidden_size >> 2;
#pragma unroll 4
for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
vec4_t<scalar_t> in = vec_input[i];
vec4_t<float> x;
x.x = static_cast<float>(in.x);
x.y = static_cast<float>(in.y);
x.z = static_cast<float>(in.z);
x.w = static_cast<float>(in.w);
if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
x.x += static_cast<float>(r.x);
x.y += static_cast<float>(r.y);
x.z += static_cast<float>(r.z);
x.w += static_cast<float>(r.w);
}
ss += x.x * x.x;
ss += x.y * x.y;
ss += x.z * x.z;
ss += x.w * x.w;
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);
__shared__ float s_rms;
if (threadIdx.x == 0) {
s_rms = rsqrtf(ss / hidden_size + epsilon);
}
__syncthreads();
*rms = s_rms;
}
// Vectorized version of vllm::compute_dynamic_per_token_scales
// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void compute_dynamic_per_token_scales(
float* __restrict__ token_scale, float* __restrict__ all_token_scales,
scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
float const rms, float const* __restrict__ scale_ub,
float const min_scaling_factor, int32_t const hidden_size,
scalar_t const* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/weight/residual to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
vec4_t<scalar_t> const* vec_weight =
reinterpret_cast<vec4_t<scalar_t> const*>(weight);
vec4_t<scalar_t> const* vec_residual = nullptr;
if constexpr (has_residual) {
vec_residual =
reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
}
constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
int32_t const num_vec_elems = hidden_size >> 2;
float block_absmax_val_maybe = 0.0f;
#pragma unroll 4
for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
vec4_t<scalar_t> in = vec_input[i];
vec4_t<scalar_t> const w = vec_weight[i];
vec4_t<float> x;
x.x = static_cast<float>(in.x);
x.y = static_cast<float>(in.y);
x.z = static_cast<float>(in.z);
x.w = static_cast<float>(in.w);
if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
x.x += static_cast<float>(r.x);
x.y += static_cast<float>(r.y);
x.z += static_cast<float>(r.z);
x.w += static_cast<float>(r.w);
}
block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
block_absmax_val_maybe = fmaxf(
block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
}
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
block_absmax_val_maybe =
BlockReduce(reduceStore)
.Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);
__shared__ float s_token_scale;
if (threadIdx.x == 0) {
float scale = 0.0f;
if (scale_ub) {
scale = min(block_absmax_val_maybe, *scale_ub);
} else {
scale = block_absmax_val_maybe;
}
// token scale computation
scale = max(scale / qmax, min_scaling_factor);
s_token_scale = scale; // shared memory store
all_token_scales[blockIdx.x] = scale; // global output store
}
__syncthreads();
*token_scale = s_token_scale;
}
// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
bool has_residual = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
scalar_t const* __restrict__ input,
scalar_t const* __restrict__ weight,
float const rms, float const scale,
int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
// Vectorized input/output/weight/residual to better utilize memory bandwidth.
vec4_t<scalar_t> const* vec_input =
reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
vec4_t<scalar_t> const* vec_weight =
reinterpret_cast<vec4_t<scalar_t> const*>(weight);
q8x4_t<scalar_out_t>* vec_output =
reinterpret_cast<q8x4_t<scalar_out_t>*>(&output[token_offset]);
vec4_t<scalar_t>* vec_residual = nullptr;
if constexpr (has_residual) {
vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
}
int32_t const num_vec_elems = hidden_size >> 2;
// TODO(luka/varun) extract into type-agnostic vectorized quant function to
// replace scaled_fp8_conversion_vec
#pragma unroll 4
for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
vec4_t<scalar_t> const in = vec_input[i];
vec4_t<scalar_t> const w = vec_weight[i];
vec4_t<float> x;
x.x = static_cast<float>(in.x);
x.y = static_cast<float>(in.y);
x.z = static_cast<float>(in.z);
x.w = static_cast<float>(in.w);
if constexpr (has_residual) {
vec4_t<scalar_t> r = vec_residual[i];
x.x += static_cast<float>(r.x);
x.y += static_cast<float>(r.y);
x.z += static_cast<float>(r.z);
x.w += static_cast<float>(r.w);
// Update residual
r.x = static_cast<scalar_t>(x.x);
r.y = static_cast<scalar_t>(x.y);
r.z = static_cast<scalar_t>(x.z);
r.w = static_cast<scalar_t>(x.w);
vec_residual[i] = r;
}
q8x4_t<scalar_out_t> out;
out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.x * rms) * w.x, scale);
out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.y * rms) * w.y, scale);
out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.z * rms) * w.z, scale);
out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
static_cast<scalar_t>(x.w * rms) * w.w, scale);
vec_output[i] = out;
}
}
} // namespace vectorized
} // namespace vllm

View File

@ -0,0 +1,81 @@
#pragma once
/**
* __device__ helper functions to deal with float -> quant datatype conversion
*/
#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"
namespace vllm {
// TODO(luka/varun): combine into common utilities for int8
// (with int8_quant_kernels.cu)
static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#ifdef USE_ROCM
static const float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
static const float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
// round
float dst = std::nearbyint(x);
// saturate
dst = std::clamp(dst, i8_min, i8_max);
return static_cast<int8_t>(dst);
#else
// CUDA path
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
#endif
}
static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<FP8_TYPE>(r);
}
template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
struct ScaledQuant;
template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
quant_type_t, is_scale_inverted,
typename std::enable_if_t<std::is_same_v<quant_type_t, int8_t>>> {
static __device__ __forceinline__ quant_type_t quant_fn(float const x,
float const scale) {
if constexpr (is_scale_inverted) {
return float_to_int8_rn(x * scale);
} else {
return float_to_int8_rn(x / scale);
}
}
};
template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
quant_type_t, is_scale_inverted,
typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
static __device__ __forceinline__ quant_type_t quant_fn(float const x,
float const scale) {
if constexpr (is_scale_inverted) {
return float_to_fp8(x * scale);
} else {
return float_to_fp8(x / scale);
}
}
};
template <typename scalar_t, typename quant_type_t, bool is_scale_inverted>
__device__ void scaled_quant_conversion(quant_type_t* __restrict__ output,
scalar_t const* __restrict__ input,
float const scale, int const tid,
int const num_elements,
int const step) {
for (int i = tid; i < num_elements; i += step) {
output[i] =
ScaledQuant<quant_type_t, is_scale_inverted>::quant_fn(input[i], scale);
}
}
} // namespace vllm
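
The only difference between the two ScaledQuant specializations is whether the scale argument is already a reciprocal. A tiny Python analogue of the int8 path, for illustration only:

def scaled_quant_int8(x: float, scale: float, is_scale_inverted: bool) -> int:
    # Multiply by a precomputed 1/scale, or divide by the scale itself,
    # then round to nearest and saturate to the int8 range.
    v = x * scale if is_scale_inverted else x / scale
    return max(-128, min(127, round(v)))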

View File

@ -0,0 +1,33 @@
#pragma once
/**
* __device__ datatypes vectorized by 4
*/
// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>
namespace vllm {
// Vectorization containers
template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};
template <typename quant_type_t>
struct __align__(4) q8x4_t {
static_assert(std::is_same_v<quant_type_t, int8_t> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
quant_type_t x;
quant_type_t y;
quant_type_t z;
quant_type_t w;
};
} // namespace vllm

View File

@ -128,6 +128,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
&fused_add_rms_norm_static_fp8_quant);
// Fused Layernorm + Quant kernels
ops.def(
"rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
"Tensor weight, Tensor! scale, float epsilon, "
"Tensor? scale_ub, Tensor!? residual) -> ()");
ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
&rms_norm_dynamic_per_token_quant);
// Rotary embedding
// Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
ops.def(

View File

@ -4,10 +4,10 @@ import torch
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.compilation.vllm_inductor_pass import is_func
from vllm.config import CompilationConfig
from .backend import TestBackend
@ -35,12 +35,16 @@ prompts = [
]
@pytest.mark.parametrize("model",
["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"])
@pytest.mark.parametrize(
"model, quant_key",
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
kFp8DynamicTokenSym)])
@pytest.mark.parametrize("do_fusion", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fix_functionalization(model: str, do_fusion: bool):
def test_fix_functionalization(model: str, quant_key: QuantKey,
do_fusion: bool):
torch.set_default_device("cuda")
config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
@ -78,8 +82,9 @@ def test_fix_functionalization(model: str, do_fusion: bool):
# OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"]
if do_fusion else [RMS_OP])
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
ops = OPS_IN_MODEL + rms_ops
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)

View File

@ -3,8 +3,9 @@ import torch
from compressed_tensors.quantization import FP8_DTYPE
import vllm.envs as envs
from vllm.compilation.fusion import (FusionPass, find_auto_fn,
find_auto_fn_maybe)
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.reshapes import RedundantReshapesPass
from vllm.config import CompilationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
@ -16,24 +17,37 @@ from .backend import TestBackend
class TestModel(torch.nn.Module):
def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
def __init__(self, hidden_size: int, eps: float, static: bool, *args,
**kwargs):
super().__init__(*args, **kwargs)
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
else:
self.scale = [None for _ in range(2)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
]
def forward(self, x):
resid = torch.relu(x)
resid = torch.sqrt(x)
y = self.norm[0](x)
x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
x2 = apply_fp8_linear(y,
self.w[0],
self.wscale[0],
self.scale[0],
use_per_token_if_dynamic=True)
# make sure resid is used for replacement to work
y2, resid = self.norm[1](x2, resid)
x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
x3 = apply_fp8_linear(y2,
self.w[1],
self.wscale[1],
self.scale[1],
use_per_token_if_dynamic=True)
y3, resid = self.norm[2](x3, resid) # use resid here
return y3
@ -42,14 +56,13 @@ class TestModel(torch.nn.Module):
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
reason="Only test on CUDA")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
torch.set_default_device("cuda")
torch.set_default_dtype(torch.float16)
if eps != 1e-5:
pytest.skip("Only test eps=1e-5 for now")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
# Reshape pass is needed for the fusion pass to work
config = CompilationConfig.PassConfig(enable_fusion=True,
@ -58,7 +71,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
fusion_pass = FusionPass.instance(config)
backend = TestBackend(reshape_pass, fusion_pass)
model = TestModel(hidden_size, eps)
model = TestModel(hidden_size, eps, static)
# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
@ -69,16 +82,28 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
model2 = torch.compile(model, backend=backend)
result2 = model2(x)
# Check that it gives the same answer
torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)
torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)
# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes
rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
# static is per-tensor, dynamic is per-token
key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
symmetric=True)
rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
fp8_quant = QUANT_OPS[key]
# In pre-nodes, fp8 quant should be present and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None

View File

@ -0,0 +1,171 @@
from typing import Optional, Tuple, Union
import pytest
import torch
import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm
DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
VEC_HIDDEN_SIZES = range(1024, 1030)
# Avoid combinatorial explosion with full Cartesian product
NUM_TOKENS_HIDDEN_SIZES = [
*[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
*[(83, i) for i in [1, 1033, 2048, 5120]],
*[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]],
*[(4096, i) for i in [1, 64, 5137]],
]
ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
SEEDS = [0]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
EPS = 1e-6
## Helpers
def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
return torch.as_tensor(x, dtype=torch.float32, device='cuda')
def ref_rms_norm(rms_norm_layer: RMSNorm,
x: torch.Tensor,
residual: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, residual = rms_norm_layer.forward_native(x, residual)
else:
out = rms_norm_layer.forward_native(x)
return out, residual
def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if scale_ub is not None:
assert quant_dtype == torch.float8_e4m3fn
# Norm
torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)
# Quant
if quant_dtype == torch.float8_e4m3fn:
torch_out, scales = ops.scaled_fp8_quant(torch_out,
scale_ub=scale_ub,
use_per_token_if_dynamic=True)
else:
assert quant_dtype == torch.int8
torch_out, scales = ops.scaled_int8_quant(torch_out)
return torch_out, scales, residual
def ref_impl(rms_norm_layer: RMSNorm,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
residual, scale_ub)
def ops_dynamic_per_token_quant(weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if residual is not None:
residual = residual.clone()
out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
quant_dtype, scale_ub,
residual)
return out, scales, residual
def ops_impl(weight: torch.Tensor,
x: torch.Tensor,
quant_dtype: torch.dtype,
residual: Optional[torch.Tensor],
scale_ub: Optional[torch.Tensor]) \
-> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
scale_ub)
@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
hidden_size: int,
add_residual: bool,
scale_ub: bool,
dtype: torch.dtype,
quant_dtype: torch.dtype,
seed: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device(device)
if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
# skip
return
layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / (hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
residual = torch.randn_like(x) * scale if add_residual else None
if scale_ub is not None:
rms_x, _ = ref_rms_norm(layer, x, residual)
scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda')
ref_out, ref_scales, ref_residual = \
ref_impl(layer, x, quant_dtype, residual, scale_ub)
ops_out, ops_scales, ops_residual = \
ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)
assert ref_out.dtype == quant_dtype
assert ops_out.dtype == quant_dtype
assert torch.allclose(ref_scales, ops_scales)
if quant_dtype == torch.int8:
# big atol to account for round-off errors.
assert torch.allclose(ref_out, ops_out, atol=1)
else:
assert torch.allclose(ref_out.to(dtype=torch.float32),
ops_out.to(dtype=torch.float32))
if add_residual:
assert torch.allclose(ref_residual, ops_residual)
output = torch.empty_like(x, dtype=quant_dtype)
scales = torch.empty((x.numel() // x.shape[-1], 1),
device=x.device,
dtype=torch.float32)
opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
(output, x, layer.weight, scales, 1e-5, scale_ub, residual))

View File

@ -249,6 +249,26 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
block_table_bound)
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
quant_dtype: torch.dtype,
scale_ub: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(input, dtype=quant_dtype)
scales = torch.empty((input.numel() // input.shape[-1], 1),
device=input.device,
dtype=torch.float32)
torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight,
scales, epsilon, scale_ub,
residual)
return output, scales
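# Illustrative call of the wrapper above (hypothetical shapes; requires a CUDA
# build of the _C extension):
#
#   x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
#   w = torch.ones(4096, dtype=torch.bfloat16, device="cuda")
#   out, scales = rms_norm_dynamic_per_token_quant(x, w, 1e-6,
#                                                  torch.float8_e4m3fn)
#   # out: fp8 tensor with x's shape; scales: (16, 1) float32, one per token.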
# quantization ops
# awq
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,

View File

@ -6,7 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass, is_func
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
@ -53,14 +54,16 @@ class FixFunctionalizationPass(VllmInductorPass):
self.insert_defunctionalized(graph, node)
self._remove(node)
# These 2 replacements avoid the most copies for LLaMa.
# rms_norm replacements avoid the most copies for LLaMa.
elif at_target == torch.ops._C.fused_add_rms_norm.default:
mutated_args = {1: 'input', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
torch.ops._C.rms_norm_static_fp8_quant.default

View File

@ -1,129 +1,517 @@
import operator
from typing import Iterable, List, Optional
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch._inductor.pattern_matcher as pm
# TODO(luka) use vllm.utils once #10836 landed
from compressed_tensors.quantization import FP8_DTYPE
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
fwd_only, register_replacement)
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._ops import OpOverload
from vllm.config import CompilationConfig
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass, is_func
from .fx_utils import find_getitem_maybe
from .multi_output_match import MultiOutputMatch
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(torch.ops._C.rms_norm.default,
result=result_rms,
input=input,
weight=weight,
epsilon=1e-5)
at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at1[1],
scale=scale)
# result
return at2[1]
def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=1e-5)
# result
return at[1]
def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default,
input=input,
residual=residual,
weight=weight,
epsilon=1e-5)
at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
result=result,
input=at[1],
scale=scale)
# result, residual
return at1[1], at[2]
def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor, scale: torch.Tensor):
at = auto_functionalized(
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=1e-5)
# result, residual
return at[1], at[2]
def empty_bf16(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
def empty_fp8(*args, **kwargs):
fp8 = torch.float8_e4m3fn
return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
def empty_fp32(*args, **kwargs):
return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
# Utilities for post-processing multi-output matches
RMS_OP = torch.ops._C.rms_norm.default
RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
op) -> Optional[torch.fx.Node]:
for node in nodes:
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
return node
return None
class QuantKey(NamedTuple):
"""
Named tuple for identifying the type of quantization.
dtype: quantized data type
static: static quantization if True, dynamic if False
per_tensor: per-tensor quantization if True, per-token if False
symmetric: symmetric if True, asymmetric if False
"""
dtype: torch.dtype
static: bool
per_tensor: bool = True
symmetric: bool = True
def __str__(self):
return (f"QuantKey({'static' if self.static else 'dynamic'},"
f"{fx.graph.dtype_abbrs[self.dtype]},"
f"{'per_tensor' if self.per_tensor else 'per_token'},"
f"{'a' if not self.symmetric else ''}symmetric)")
# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
node = find_auto_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
QUANT_OPS: Dict[QuantKey, OpOverload] = {
kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa
kFp8DynamicTensorSym:
torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa
kFp8DynamicTokenSym:
torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa
}
# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: torch.fx.Node,
idx: int) -> Optional[torch.fx.Node]:
for user in node.users:
if is_func(user, operator.getitem) and user.args[1] == idx:
return user
return None
class FusedRMSQuantKey(NamedTuple):
"""
Named tuple for identifying the type of RMSNorm + quant fusion.
quant: type of quantization
fused_add: does the op also perform the residual add
"""
quant: QuantKey
fused_add: bool
def __str__(self):
return (f"FusedQuantKey({self.quant}, with"
f"{'' if self.fused_add else 'out'} residual)")
# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
FusedRMSQuantKey(kFp8StaticTensorSym, False):
torch.ops._C.rms_norm_static_fp8_quant.default, # noqa
FusedRMSQuantKey(kFp8StaticTensorSym, True):
torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa
FusedRMSQuantKey(kFp8DynamicTokenSym, False):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
FusedRMSQuantKey(kFp8DynamicTokenSym, True):
torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa
}
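# Illustration (not part of the original file): resolving the ops for dynamic
# per-token fp8 quantization with a fused residual add.
_example_key = FusedRMSQuantKey(quant=kFp8DynamicTokenSym, fused_add=True)
assert QUANT_OPS[_example_key.quant] is \
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default
assert FUSED_OPS[_example_key] is \
    torch.ops._C.rms_norm_dynamic_per_token_quant.default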
class QuantMultiOutputMatch(MultiOutputMatch):
def __init__(self, match: pm.Match, quant_op, fused_op):
super().__init__(match)
assert isinstance(quant_op, OpOverload)
assert isinstance(fused_op, OpOverload)
self.QUANT_OP = quant_op # in-place quant op
self.FUSED_OP = fused_op # in-place fused quant op
def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
int]],
**kwargs):
"""
This utility function inserts an auto-functionalized node for FUSED_OP.
It also correctly sets its meta value and rebinds the users of the
unfused nodes to use the fused node instead.
:param fused_return_mapping: A dictionary, mapping from getitem indices
of the fused node result to a tuple of the old node and a getitem index.
:param kwargs: kwargs that get directly forwarded to the auto_fn node
Example:
If we want to replace this graph:
_, x1, x2 = auto_fn(op1)
_, y1, y2 = auto_fn(op2)
with
_, x1, y2, x2 = auto_fn(FUSED_OP)
we would call:
insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)})
Note that the 0th element is None for auto-functionalized in-place ops.
Hence, others appear 1-indexed.
"""
fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
indices = fused_return_mapping.keys()
getitem_nodes = self.insert_getitems(fused_node, indices)
# Prepare the meta value, use a list so it's mutable
meta_val = [None] * (max(indices) + 1)
# Iterate through elements of the tuple produced by fused_node
for idx, getitem_node in zip(indices, getitem_nodes):
old_node, old_idx = fused_return_mapping[idx]
# If the old value was never used, the old_getitem might not exist
old_getitem = find_getitem_maybe(old_node, old_idx)
if old_getitem is not None:
# Rebind the users of match getitem nodes to use the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
old_getitem.replace_all_uses_with(getitem_node)
getitem_node.meta["val"] = old_getitem.meta["val"]
# Extract the appropriate meta value
# It is present even if the getitem node does not exist
meta_val[idx] = old_node.meta["val"][old_idx]
# Fix the meta value on the new fused node
fused_node.meta["val"] = tuple(meta_val)
class RMSNormQuantPattern:
def __init__(self, epsilon: float, key: FusedRMSQuantKey):
self.epsilon = epsilon
self.quant_dtype = key.quant.dtype
assert key.quant in QUANT_OPS, \
f"unsupported quantization scheme {key.quant}"
self.QUANT_OP = QUANT_OPS[key.quant]
assert key in FUSED_OPS, \
f"unsupported fused rmsnorm+quant op for {key}"
self.FUSED_OP = FUSED_OPS[key]
class RMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
symmetric=True):
fused_key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
super().__init__(epsilon, fused_key)
def register(self, pm_pass: PatternMatcherPass):
# Cannot use methods, as the self argument affects tracing
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(self.QUANT_OP,
result=result,
input=at1[1],
scale=scale)
# result
return at2[1]
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon)
# result
return at[1]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
pm_pass)
class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=True,
per_tensor=True,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(self.QUANT_OP,
result=result,
input=at[1],
scale=scale)
# result, residual
return at1[1], at[2]
def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
residual=residual,
weight=weight,
scale=scale,
epsilon=self.epsilon)
# result, residual
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and residual.
# The auto_fn node returns a tuple of (None, result, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
# 0 is always None
fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
self.insert_fused_node(fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
**kwargs)
class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool,
symmetric=True):
key = FusedRMSQuantKey(fused_add=False,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at1 = auto_functionalized(RMS_OP,
result=result_rms,
input=input,
weight=weight,
epsilon=self.epsilon)
at2 = auto_functionalized(self.QUANT_OP,
result=result,
input=at1[1],
scale=scale,
scale_ub=None)
# result, scale
return at2[1], at2[2]
def replacement(result: torch.Tensor, result_rms: torch.Tensor,
input: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=None)
# result, scale
return at[1], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # result_rms
empty_bf16(5, 4), # input
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 1
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract the result and scale.
# The auto_fn node returns a tuple of (None, result, scale).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
del kwargs["result_rms"] # not used in the fused op
fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
residual=None, # not used but required
**kwargs)
class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
def __init__(self,
epsilon: float,
quant_dtype: torch.dtype,
per_tensor: bool = True,
symmetric=True):
key = FusedRMSQuantKey(fused_add=True,
quant=QuantKey(dtype=quant_dtype,
static=False,
per_tensor=per_tensor,
symmetric=symmetric))
super().__init__(epsilon, key)
def register(self, pm_pass: PatternMatcherPass,
record_match: Callable[[MultiOutputMatch], bool]):
def pattern(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(RMS_ADD_OP,
input=input,
residual=residual,
weight=weight,
epsilon=self.epsilon)
at1 = auto_functionalized(self.QUANT_OP,
result=result,
input=at[1],
scale=scale,
scale_ub=None)
# result, residual, scale
return at1[1], at[2], at1[2]
def replacement(result: torch.Tensor, input: torch.Tensor,
residual: torch.Tensor, weight: torch.Tensor,
scale: torch.Tensor):
at = auto_functionalized(self.FUSED_OP,
result=result,
input=input,
weight=weight,
scale=scale,
epsilon=self.epsilon,
scale_ub=None,
residual=residual)
# result, residual, scale
return at[1], at[3], at[2]
inputs = [
torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result
empty_bf16(5, 4), # input
empty_bf16(5, 4), # residual
empty_bf16(1, 5), # weight
empty_fp32(1, 1) # scale
]
pm.register_replacement(
pattern,
replacement,
inputs,
pm.fwd_only,
pm_pass,
extra_check=lambda m: record_match(
self.Match(m, self.QUANT_OP, self.FUSED_OP)))
class Match(QuantMultiOutputMatch):
def process(self):
# Find the nodes in the match that we need to rebind
rms_node = self.find_auto_fn(RMS_ADD_OP)
quant_node = self.find_auto_fn(self.QUANT_OP)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 2
# First, insert a new auto_functionalized node for the fused op,
# as well as getitem nodes to extract result, scale, and residual.
# The auto_fn node returns a tuple (None, result, scale, residual).
#
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) # noqa
# result_node_new = at[1]
# scale_node_new = at[2]
# residual_node_new = at[3]
with self.inserting_after_match():
# Missing epsilon, scalars cannot be inputs to the pattern
kwargs = self.match.kwargs.copy()
fused_return_mapping = {
1: (quant_node, 1), # result
2: (quant_node, 2), # scale
3: (rms_node, 2), # residual
}
self.insert_fused_node(
fused_return_mapping,
epsilon=rms_node.kwargs["epsilon"],
scale_ub=None, # not used but required
**kwargs)
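To make the fused op's semantics concrete, here is a short pure-PyTorch sketch (an illustration, not part of this commit) of what rms_norm followed by dynamic per-token fp8 quantization computes, including the optional fused residual add covered by the two dynamic-quant patterns above. It assumes symmetric per-token scales and does not claim to match the CUDA kernel's exact numerics.

from typing import Optional, Tuple

import torch

def rms_norm_dynamic_fp8_ref(
        x: torch.Tensor,
        weight: torch.Tensor,
        eps: float = 1e-6,
        residual: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    # Fused-add variant: the kernel also writes the sum back into residual.
    if residual is not None:
        x = x + residual
        residual = x
    # RMSNorm, computed in fp32 for accuracy.
    xf = x.float()
    normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + eps)
    normed = normed * weight.float()
    # Dynamic per-token (per-row) symmetric scale.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = normed.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (normed / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return q, scale, residual

The (result, scale) pair, plus the updated residual in the fused-add case, corresponds to the tuple slots that fused_return_mapping above rebinds (at[1], at[2], at[3]).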
class FusionPass(VllmInductorPass):
@ -158,41 +546,39 @@ class FusionPass(VllmInductorPass):
"FusionPass singleton instance already exists"
super().__init__(config)
self.matches: List[Match] = []
self.matches: List[MultiOutputMatch] = []
self.patterns: PatternMatcherPass = PatternMatcherPass(
pass_name="fusion_pass")
# Fuse rms_norm + static_scaled_fp8_quant into
# rms_norm_static_fp8_quant
inputs = [
empty_fp8(5, 4),
empty_bf16(5, 4),
empty_bf16(5, 4),
empty_bf16(1, 5),
empty_fp32(1, 1)
]
register_replacement(rms_pattern_static, rms_replacement_static,
inputs, fwd_only, self.patterns)
for epsilon in [1e-5, 1e-6]:
# Fuse rms_norm + static fp8 quant
RMSNormStaticQuantPattern(epsilon,
FP8_DTYPE).register(self.patterns)
# Fuse fused_add_rms_norm + static_scaled_fp8_quant into
# fused_add_rms_norm_static_fp8_quant
# Because pattern has 2 outputs, we need to manually process the match
# (see process_matches)
inputs = [
empty_fp8(5, 4),
empty_bf16(5, 4),
empty_bf16(5, 4),
empty_bf16(1, 5),
empty_fp32(1, 1)
]
register_replacement(rms_pattern_residual_static,
rms_replacement_residual_static,
inputs,
fwd_only,
self.patterns,
extra_check=lambda m: self.record_match(m))
# Matches for patterns below have 2 or more outputs,
# so we need to process them manually (see process_matches)
def record_match(self, match: Match) -> bool:
# Fuse rms_norm + static fp8 quant
FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
self.patterns, self.record_match)
# Fuse rms_norm + dynamic per-token fp8 quant
RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
per_tensor=False).register(
self.patterns, self.record_match)
# Fuse fused_add_rms_norm + dynamic per-token fp8 quant
FusedAddRMSNormDynamicQuantPattern(epsilon,
FP8_DTYPE,
per_tensor=False).register(
self.patterns,
self.record_match)
# WARNING: This is a hack to clear the pattern matcher cache
# and allow multiple values of epsilon.
torch._inductor.pattern_matcher._seen_patterns.clear()
def record_match(self, match: MultiOutputMatch) -> bool:
# Hijack the extra_check to record the match and
# save it for post-processing.
self.matches.append(match)
@ -200,83 +586,20 @@ class FusionPass(VllmInductorPass):
# Return False to prevent automatic replacement.
return False
def process_matches(self, graph: torch.fx.Graph):
def process_matches(self, graph: fx.Graph):
"""
Manually process multi-output matches and replace them with fused nodes.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
See MultiOutputMatch for more details.
"""
for match in self.matches:
# To avoid use-before-definition errors, insert replacement nodes
# after the last node in the match.
# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(graph.nodes):
if last_node_in_match in match.nodes:
break
else:
raise ValueError("No nodes in graph")
# Insert a new auto_functionalized node for the fused operation,
# as well as getitem nodes to extract the result and residual.
# The auto_functionalized node returns a tuple of
# (None, result, residual) - None is the function return value.
# The resulting graph looks like this:
# at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa
# result_node_new = at[1]
# residual_node_new = at[2]
with graph.inserting_after(last_node_in_match):
kwargs = match.kwargs
kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm
fused_node = graph.call_function(
auto_functionalized,
(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
),
kwargs=kwargs)
graph.inserting_after(fused_node)
result_node_new = graph.call_function(operator.getitem,
(fused_node, 1))
residual_node_new = graph.call_function(
operator.getitem, (fused_node, 2))
# Last part of replacement is rebinding the users of nodes in the
# match to use the new nodes.
# Find the nodes in the match that we need to rebind
rms_node = find_auto_fn(match.nodes,
torch.ops._C.fused_add_rms_norm.default)
quant_node = find_auto_fn(
match.nodes, torch.ops._C.static_scaled_fp8_quant.default)
assert len(rms_node.users) == 2
assert len(quant_node.users) == 1
# meta["val"] is used by de-functionalization and has to contain the
# value of the node (tuple of tensors) that would be returned by the
# functionalized node during tracing.
rms_tup = rms_node.meta["val"]
quant_tup = quant_node.meta["val"]
# The result of fused_node must be a tuple with the first element
# None (the function return value) and the remaining elements
# representing the mutated inputs.
fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2])
fused_node.meta["val"] = fused_tup
# Find the getitem nodes and replace their uses with the new nodes.
# The old nodes will be removed by DCE at the end of the pass.
find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)
match.process()
# Finally, remove matched nodes
graph.eliminate_dead_code()
assert all(node not in graph.nodes for match in self.matches
for node in match.nodes)
for node in match.match.nodes)
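A tiny standalone check (illustration only) of the cleanup step: once match.process() has rebound every user of the matched nodes, those nodes have no users left and eliminate_dead_code drops them, which is what the assertion above verifies.

import torch
from torch import fx

g = fx.Graph()
x = g.placeholder("x")
g.call_function(torch.relu, (x,))        # ends up with no users -> dead code
out = g.call_function(torch.sigmoid, (x,))
g.output(out)

g.eliminate_dead_code()
assert all(node.target is not torch.relu for node in g.nodes)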
def __call__(self, graph: torch.fx.Graph):
def __call__(self, graph: fx.Graph):
self.begin()
self.dump_graph(graph, "before_fusion")
View File
@ -0,0 +1,42 @@
import operator
from typing import Iterable, Optional
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload
def is_func(node: fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
op: OpOverload) -> Optional[fx.Node]:
for node in nodes:
if is_func(node, auto_functionalized) and node.args[0] == op: # noqa
return node
return None
# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
node = find_auto_fn_maybe(nodes, op)
assert node is not None, f"Could not find {op} in nodes {nodes}"
return node
# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
for user in node.users:
if is_func(user, operator.getitem) and user.args[1] == idx:
return user
return None
# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
ret = find_getitem_maybe(node, idx)
assert ret is not None, f"Could not find getitem {idx} in node {node}"
return ret
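A self-contained toy, independent of vLLM, showing the user-scanning idea behind find_getitem_maybe on a hand-built FX graph; torch.split is only a stand-in for a tuple-producing node such as auto_functionalized, and toy_find_getitem re-implements the helper above for illustration.

import operator

import torch
from torch import fx

graph = fx.Graph()
x = graph.placeholder("x")
tup = graph.call_function(torch.split, (x, 2))           # tuple-producing node
a = graph.call_function(operator.getitem, (tup, 0))
b = graph.call_function(operator.getitem, (tup, 1))
graph.output((a, b))

def toy_find_getitem(node: fx.Node, idx: int) -> fx.Node:
    # Same scan as find_getitem above: walk the node's users.
    for user in node.users:
        if (user.op == "call_function" and user.target is operator.getitem
                and user.args[1] == idx):
            return user
    raise AssertionError(f"no getitem({idx}) user on {node}")

print(toy_find_getitem(tup, 1))  # -> the getitem node for index 1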
View File
@ -0,0 +1,105 @@
import abc
import operator
from abc import abstractmethod
from typing import Iterable, List, Tuple
from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor import pattern_matcher as pm
from torch._ops import OpOverload
from vllm.compilation.fx_utils import find_auto_fn
class MultiOutputMatch(abc.ABC):
"""
This class provides utilities to process multi-output matches and
manually insert replacements.
This is necessary because the automatic replacement for multi-output
matches is broken: https://github.com/pytorch/pytorch/issues/137280
"""
def __init__(self, match: pm.Match):
self.match = match
@abstractmethod
def process(self):
"""
Process a multi-output match and manually insert the replacement.
This method should:
1. Insert the replacement nodes after the last node in the match.
2. Rebind the users of nodes in the match to use the new nodes.
3. Set meta["val"] for de-functionalization.
The result of an auto-functionalized node is a tuple of tensors.
The first element is the return value of the function, usually None.
The remaining elements are the mutated args of the function.
All auto-functionalized nodes must contain a proper meta["val"],
as it is used by de-functionalization. meta["val"] has to contain the
value of the node (tuple of tensors) that would be returned by the
functionalized node during tracing.
Existing nodes in the graph all have this property set, but we have
to set it manually for new nodes we insert.
Example:
# op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
# at.meta["val"] = (None, a, c)
"""
raise NotImplementedError
@property
def nodes(self) -> List[fx.Node]:
return self.match.nodes
@property
def graph(self) -> fx.Graph:
return self.match.graph
def find_auto_fn(self, op) -> fx.Node:
"""
Find the first auto_functionalized node with the given op in the match.
"""
return find_auto_fn(self.nodes, op)
def inserting_after_match(self):
"""
Insert nodes after the last node in the match.
This is done to avoid use-before-definition errors after inserting
replacement nodes.
"""
# match.nodes is not guaranteed to be sorted.
# Find the last node in the match.
for last_node_in_match in reversed(self.graph.nodes):
if last_node_in_match in self.match.nodes:
break
else:
raise ValueError("No nodes in graph")
return self.graph.inserting_after(last_node_in_match)
def insert_getitems(self, tuple_node: fx.Node,
indices: Iterable[int]) -> Tuple[fx.Node, ...]:
"""
Insert operator.getitem nodes to extract elements from a tuple node.
:param tuple_node: The tuple node to extract elements from.
:param indices: The indices of the elements to extract.
:return: Tuple of the new getitem nodes, corresponding to the indices.
"""
with self.graph.inserting_after(tuple_node):
return tuple(
self.graph.call_function(operator.getitem, (tuple_node, idx))
for idx in indices)
def insert_auto_fn(self, op: OpOverload, kwargs):
"""
Insert an auto_functionalized node with the given op and kwargs.
"""
return self.graph.call_function(auto_functionalized, (op, ),
kwargs=kwargs)
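A minimal standalone sketch (illustration only) of the insertion-point trick documented in inserting_after_match: graph.nodes is topologically ordered, so the first matched node encountered when walking in reverse is the last one in program order, and inserting after it avoids use-before-definition errors.

import torch
from torch import fx

graph = fx.Graph()
x = graph.placeholder("x")
y = graph.call_function(torch.relu, (x,))
z = graph.call_function(torch.sigmoid, (y,))
graph.output(z)

match_nodes = {y}  # pretend these nodes came from a pattern match
for last_node_in_match in reversed(graph.nodes):
    if last_node_in_match in match_nodes:
        break
else:
    raise ValueError("No nodes in graph")

with graph.inserting_after(last_node_in_match):
    # Anything created here is defined after every matched node it may use.
    graph.call_function(torch.tanh, (y,))
print(graph)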
View File
@ -5,7 +5,8 @@ from torch import SymInt
from vllm.logger import init_logger
from .vllm_inductor_pass import VllmInductorPass, is_func
from .fx_utils import is_func
from .vllm_inductor_pass import VllmInductorPass
logger = init_logger(__name__)
View File
@ -16,10 +16,6 @@ from .inductor_pass import InductorPass
logger = init_logger(__name__)
def is_func(node: torch.fx.Node, target) -> bool:
return node.op == "call_function" and node.target == target
class VllmInductorPass(InductorPass):
"""
An inductor pass with access to vLLM PassConfig.