From 30870b4f66414020645608b81dced94d8a99111c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Luka=20Govedi=C4=8D?= Date: Thu, 12 Dec 2024 22:19:23 -0500 Subject: [PATCH] [torch.compile] Dynamic fp8 + rms_norm fusion (#10906) Signed-off-by: luka Co-authored-by: Varun Sundar Rabindranath --- CMakeLists.txt | 3 +- .../fused_kernels/layernorm_rms_benchmarks.py | 173 +++++ csrc/dispatch_utils.h | 14 + csrc/ops.h | 8 + csrc/quantization/fp8/common.cuh | 26 +- ...fused_layernorm_dynamic_per_token_quant.cu | 160 ++++ .../fused_kernels/layernorm_utils.cuh | 327 ++++++++ .../fused_kernels/quant_conversions.cuh | 81 ++ csrc/quantization/vectorization.cuh | 33 + csrc/torch_bindings.cpp | 8 + tests/compile/test_functionalization.py | 21 +- tests/compile/test_fusion.py | 61 +- tests/kernels/test_fused_quant_layernorm.py | 171 +++++ vllm/_custom_ops.py | 20 + vllm/compilation/fix_functionalization.py | 9 +- vllm/compilation/fusion.py | 717 +++++++++++++----- vllm/compilation/fx_utils.py | 42 + vllm/compilation/multi_output_match.py | 105 +++ vllm/compilation/reshapes.py | 3 +- vllm/compilation/vllm_inductor_pass.py | 4 - 20 files changed, 1735 insertions(+), 251 deletions(-) create mode 100644 benchmarks/fused_kernels/layernorm_rms_benchmarks.py create mode 100644 csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu create mode 100644 csrc/quantization/fused_kernels/layernorm_utils.cuh create mode 100644 csrc/quantization/fused_kernels/quant_conversions.cuh create mode 100644 csrc/quantization/vectorization.cuh create mode 100644 tests/kernels/test_fused_quant_layernorm.py create mode 100644 vllm/compilation/fx_utils.py create mode 100644 vllm/compilation/multi_output_match.py diff --git a/CMakeLists.txt b/CMakeLists.txt index c78cdc77..bf19b3d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" @@ -300,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. 
- cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py new file mode 100644 index 00000000..ef91f9f8 --- /dev/null +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -0,0 +1,173 @@ +import pickle as pkl +import time +from dataclasses import dataclass +from itertools import product +from typing import Callable, Iterable, List, Optional + +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from tqdm import tqdm + +import vllm._custom_ops as ops +from vllm.model_executor.layers.layernorm import RMSNorm + + +@dataclass +class bench_params_t: + num_tokens: int + hidden_size: int + add_residual: bool + dtype: torch.dtype + + def description(self): + return (f'N {self.num_tokens} ' + f'x D {self.hidden_size} ' + f'x R {self.add_residual} ' + f'x DT {self.dtype}') + + +def get_bench_params() -> List[bench_params_t]: + ## Test Fixtures + NUM_TOKENS = [2**x for x in range(11)] + HIDDEN_SIZES = list(range(1024, 8129, 1024)) + ADD_RESIDUAL = [True, False] + DTYPES = [torch.bfloat16, torch.float] + + combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) + bench_params = list(map(lambda x: \ + bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + return bench_params + + +# Reference impls +def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _, _ = ops.scaled_int8_quant(torch_out) + + +def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + # Norm + torch_out = None + if residual is None: + torch_out = rms_norm_layer.forward_cuda(x, residual) + else: + torch_out, _ = rms_norm_layer.forward_cuda(x, residual) + + # Quant + torch_out, _ = ops.scaled_fp8_quant(torch_out) + + +def fused_impl( + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype): + out, _ = ops.rms_norm_dynamic_per_token_quant(x, + rms_norm_layer.weight, + 1e-6, + quant_dtype, + residual=residual) + + +# Bench functions +def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, + quant_dtype: torch.dtype, label: str, sub_label: str, + fn: Callable, description: str) -> TMeasurement: + + min_run_time = 1 + + globals = { + "rms_norm_layer": rms_norm_layer, + "x": x, + "residual": residual, + "quant_dtype": quant_dtype, + "fn": fn, + } + return TBenchmark.Timer( + stmt="fn(rms_norm_layer, x, residual, quant_dtype)", + globals=globals, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + +def bench(params: bench_params_t, label: str, sub_label: str) \ + -> Iterable[TMeasurement]: + + # Make inputs + layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + # Make inputs + scale = 1 / params.hidden_size + x = torch.randn(params.num_tokens, + 
params.hidden_size, + dtype=params.dtype, + device='cuda') * scale + residual = (torch.randn_like(x) * scale).to(device='cuda') \ + if params.add_residual else None + + timers = [] + + # unfused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, + unfused_int8_impl, "unfused_int8_impl")) + + # unfused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + unfused_fp8_impl, "unfused_fp8_impl")) + + # fused int8 impl. + timers.append( + bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, + "fused_int8_impl")) + + # fused fp8 impl. + timers.append( + bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, + fused_impl, "fused_fp8_impl")) + + print_timers(timers) + + return timers + + +# launch bench +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def main(): + torch.set_default_device('cuda') + bench_params = get_bench_params() + + timers = [] + for bp in tqdm(bench_params): + timers.extend( + bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + print_timers(timers) + + # pickle all the results + timestamp = int(time.time()) + with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f: + pkl.dump(timers, f) + + +if __name__ == '__main__': + main() diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index a634e1c3..03414b7e 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -14,6 +14,20 @@ #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) +// TODO(luka/varun): use FP8_TYPE macro after refactoring +#ifndef USE_ROCM + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#else + #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) +#endif + +#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__)) + #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ diff --git a/csrc/ops.h b/csrc/ops.h index ea001190..816b471d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& weight, torch::Tensor& scale, double epsilon); +void rms_norm_dynamic_per_token_quant(torch::Tensor& out, + torch::Tensor const& input, + torch::Tensor const& weight, + torch::Tensor& scales, + double const epsilon, + std::optional scale_ub, + std::optional residual); + void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index d7c0297d..15bd5b6e 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -1,6 +1,9 @@ #pragma once +#include "quantization/vectorization.cuh" + #include +#include #ifndef USE_ROCM #include @@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz; // issue when running dynamic quantization. Here use 224.0f for rocm. 
constexpr auto FP8_E4M3_MAX = 224.0f; #endif +constexpr static auto kFp8Type = c10::CppTypeToScalarType::value; namespace vllm { @@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale, } } -template -struct __align__(8) vec4_t { - scalar_t x; - scalar_t y; - scalar_t z; - scalar_t w; -}; - -typedef struct __align__(4) { - FP8_TYPE x; - FP8_TYPE y; - FP8_TYPE z; - FP8_TYPE w; -} -float8x4_t; - template __device__ float thread_max_vec(scalar_t const* __restrict__ input, int64_t const num_elems, int const tid, @@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out, float const scale, int64_t const num_elems, int const tid, int const step) { + using float8x4_t = q8x4_t; // Vectorized input/output to better utilize memory bandwidth. - vec4_t const* vectorized_in = - reinterpret_cast const*>(input); - float8x4_t* vectorized_out = reinterpret_cast(out); + auto const* vectorized_in = reinterpret_cast const*>(input); + auto* vectorized_out = reinterpret_cast(out); int64_t const num_vec_elems = num_elems >> 2; diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu new file mode 100644 index 00000000..3c4f183b --- /dev/null +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -0,0 +1,160 @@ + +#include +#include + +#include "../../dispatch_utils.h" +#include "layernorm_utils.cuh" +#include "quant_conversions.cuh" + +namespace vllm { + +template +__device__ void rms_norm_dynamic_per_token_quant_vec( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute rms + vllm::vectorized::compute_rms( + &rms, input, hidden_size, var_epsilon, residual); + + // Compute scale + vllm::vectorized::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::vectorized::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert token_scale for exact match with FBGemm + vllm::vectorized::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} + +// RMS norm + quant kernel +template +__global__ void rms_norm_dynamic_per_token_quant_kernel( + scalar_out_t* __restrict__ out, // [..., hidden_size] + float* __restrict__ scales, // [num_tokens] + scalar_t const* __restrict__ input, // [..., hidden_size] + scalar_t const* __restrict__ weight, // [hidden_size] + float const* scale_ub, float const var_epsilon, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + // For vectorization, token_input and token_output pointers need to be + // aligned at 8-byte and 4-byte addresses respectively. 
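+  // Illustrative note: vec4_t<scalar_t> packs four input scalars (8 bytes for
+  // fp16/bf16) and q8x4_t<scalar_out_t> packs four quantized outputs
+  // (4 bytes), so the fast path below requires hidden_size % 4 == 0; e.g.
+  // hidden_size = 2048 takes the vectorized path, while 5137 (exercised in
+  // tests/kernels/test_fused_quant_layernorm.py) falls back to the scalar
+  // path later in this kernel.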
+ bool const can_vectorize = hidden_size % 4 == 0; + + if (can_vectorize) { + return rms_norm_dynamic_per_token_quant_vec( + out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor, + hidden_size, residual); + } + + float rms = 0.0f; + float token_scale = 0.0f; + + // Compute RMS + vllm::compute_rms(&rms, input, hidden_size, + var_epsilon, residual); + // Compute Scale + vllm::compute_dynamic_per_token_scales( + &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor, + hidden_size, residual); + + // RMS Norm + Quant + if constexpr (std::is_same_v) { + vllm::norm_and_quant( + out, input, weight, rms, 1.0f / token_scale, hidden_size, residual); + } else { + // FP8 - Do not invert s_token_scale for exact match with FBGemm + vllm::norm_and_quant( + out, input, weight, rms, token_scale, hidden_size, residual); + } +} +} // namespace vllm + +// Residual add + RMS norm + dynamic per token +template +void rms_norm_dynamic_per_token_quant_dispatch( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional const& scale_ub, + std::optional& residual) { + int32_t hidden_size = input.size(-1); + int32_t num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const float min_scaling_factor = + out.dtype() == torch::kInt8 + ? std::numeric_limits::epsilon() + : 1.0f / (std::numeric_limits::max() * 512.f); + + if (residual.has_value()) { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, + residual->data_ptr()); + }); + + } else { + VLLM_DISPATCH_QUANT_TYPES( + out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] { + vllm::rms_norm_dynamic_per_token_quant_kernel + <<>>( + out.data_ptr(), scales.data_ptr(), + input.data_ptr(), weight.data_ptr(), + scale_ub.has_value() ? 
scale_ub->data_ptr() : nullptr, + var_epsilon, min_scaling_factor, hidden_size, nullptr); + }); + } +} + +void rms_norm_dynamic_per_token_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& weight, // [hidden_size] + torch::Tensor& scales, // [num_tokens] + double const var_epsilon, // Variance epsilon used in norm calculation + std::optional scale_ub, std::optional residual) { + TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8); + TORCH_CHECK(out.is_contiguous() && input.is_contiguous()); + + if (scale_ub.has_value()) { + TORCH_CHECK(out.dtype() == kFp8Type); + } + TORCH_CHECK(scales.dtype() == torch::kFloat32); + + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] { + rms_norm_dynamic_per_token_quant_dispatch( + out, input, weight, scales, var_epsilon, scale_ub, residual); + }); +} diff --git a/csrc/quantization/fused_kernels/layernorm_utils.cuh b/csrc/quantization/fused_kernels/layernorm_utils.cuh new file mode 100644 index 00000000..cec6b54e --- /dev/null +++ b/csrc/quantization/fused_kernels/layernorm_utils.cuh @@ -0,0 +1,327 @@ +#pragma once + +/** + * __device__ layernorm utilities. + */ + +#include "quantization/vectorization.cuh" +#include "quant_conversions.cuh" + +#ifndef USE_ROCM + #include +#else + #include +#endif + +namespace vllm { + +// has_residual must be true, if residual is not a nullptr +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + // sum of squares + float ss = 0.0f; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + ss += x * x; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + float block_absmax_val_maybe = 0.0f; + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + } + + x = static_cast(static_cast(x * rms) * weight[i]); + block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x)); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, 
*scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // Shared memory store + all_token_scales[blockIdx.x] = scale; // Global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) { + float x = static_cast(input[token_offset + i]); + if constexpr (has_residual) { + x += static_cast(residual[token_offset + i]); + residual[token_offset + i] = static_cast(x); + } + // Norm + x = static_cast(static_cast(x * rms) * weight[i]); + // Quant + output[token_offset + i] = + ScaledQuant::quant_fn(x, scale); + } +} + +namespace vectorized { + +// Compute 1.0/rms(input) +// hidden_size must be a multiple of 4 +template +__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input, + int32_t const hidden_size, float const epsilon, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + + // Vectorized input/output to better utilize memory bandwidth. + vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + // sum of squares + float ss = 0.0f; + + int32_t const num_vec_elems = hidden_size >> 2; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + ss += x.x * x.x; + ss += x.y * x.y; + ss += x.z * x.z; + ss += x.w * x.w; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x); + + __shared__ float s_rms; + if (threadIdx.x == 0) { + s_rms = rsqrtf(ss / hidden_size + epsilon); + } + __syncthreads(); + + *rms = s_rms; +} + +// Vectorized version of vllm::compute_dynamic_per_token_scales +// hidden_size must be a multiple of 4 +template +__device__ void compute_dynamic_per_token_scales( + float* __restrict__ token_scale, float* __restrict__ all_token_scales, + scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight, + float const rms, float const* __restrict__ scale_ub, + float const min_scaling_factor, int32_t const hidden_size, + scalar_t const* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + // Vectorized input/weight/residual to better utilize memory bandwidth. 
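+  // Worked example (illustrative only): for an fp8 output type qmax is 448,
+  // so a token whose max |x * rms * w| is 2.0 gets
+  //   scale = max(min(2.0, *scale_ub) / 448, min_scaling_factor)
+  // when scale_ub is provided, matching the scalar implementation above.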
+ vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + vec4_t const* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = + reinterpret_cast const*>(&residual[token_offset]); + } + + constexpr scalar_out_t qmax{std::numeric_limits::max()}; + + int32_t const num_vec_elems = hidden_size >> 2; + float block_absmax_val_maybe = 0.0f; + +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + } + + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.x * rms) * w.x)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.y * rms) * w.y)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.z * rms) * w.z)); + block_absmax_val_maybe = fmaxf( + block_absmax_val_maybe, fabs(static_cast(x.w * rms) * w.w)); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + block_absmax_val_maybe = + BlockReduce(reduceStore) + .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x); + + __shared__ float s_token_scale; + if (threadIdx.x == 0) { + float scale = 0.0f; + if (scale_ub) { + scale = min(block_absmax_val_maybe, *scale_ub); + } else { + scale = block_absmax_val_maybe; + } + // token scale computation + scale = max(scale / qmax, min_scaling_factor); + s_token_scale = scale; // shared memory store + all_token_scales[blockIdx.x] = scale; // global output store + } + __syncthreads(); + + *token_scale = s_token_scale; +} + +// hidden_size must be a multiple of 4 +template +__device__ void norm_and_quant(scalar_out_t* __restrict__ output, + scalar_t const* __restrict__ input, + scalar_t const* __restrict__ weight, + float const rms, float const scale, + int32_t const hidden_size, + scalar_t* __restrict__ residual = nullptr) { + int64_t const token_offset = blockIdx.x * static_cast(hidden_size); + ; + + // Vectorized input/output/weight/residual to better utilize memory bandwidth. 
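+  // Note (explanatory, mirrors the callers in
+  // fused_layernorm_dynamic_per_token_quant.cu): for int8 the caller passes
+  // 1.0f / token_scale with is_scale_inverted = true so ScaledQuant
+  // multiplies, while for fp8 it passes token_scale with
+  // is_scale_inverted = false so ScaledQuant divides, for an exact match with
+  // FBGemm. When has_residual is true, the residual buffer is also updated in
+  // place with x + residual before normalization.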
+ vec4_t const* vec_input = + reinterpret_cast const*>(&input[token_offset]); + vec4_t const* vec_weight = + reinterpret_cast const*>(weight); + q8x4_t* vec_output = + reinterpret_cast*>(&output[token_offset]); + vec4_t* vec_residual = nullptr; + if constexpr (has_residual) { + vec_residual = reinterpret_cast*>(&residual[token_offset]); + } + + int32_t const num_vec_elems = hidden_size >> 2; + +// TODO(luka/varun) extract into type-agnostic vectorized quant function to +// replace scaled_fp8_conversion_vec +#pragma unroll 4 + for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) { + vec4_t const in = vec_input[i]; + vec4_t const w = vec_weight[i]; + + vec4_t x; + x.x = static_cast(in.x); + x.y = static_cast(in.y); + x.z = static_cast(in.z); + x.w = static_cast(in.w); + if constexpr (has_residual) { + vec4_t r = vec_residual[i]; + x.x += static_cast(r.x); + x.y += static_cast(r.y); + x.z += static_cast(r.z); + x.w += static_cast(r.w); + // Update residual + r.x = static_cast(x.x); + r.y = static_cast(x.y); + r.z = static_cast(x.z); + r.w = static_cast(x.w); + vec_residual[i] = r; + } + + q8x4_t out; + out.x = ScaledQuant::quant_fn( + static_cast(x.x * rms) * w.x, scale); + out.y = ScaledQuant::quant_fn( + static_cast(x.y * rms) * w.y, scale); + out.z = ScaledQuant::quant_fn( + static_cast(x.z * rms) * w.z, scale); + out.w = ScaledQuant::quant_fn( + static_cast(x.w * rms) * w.w, scale); + vec_output[i] = out; + } +} + +} // namespace vectorized + +} // namespace vllm diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh new file mode 100644 index 00000000..f8a98722 --- /dev/null +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -0,0 +1,81 @@ +#pragma once + +/** + * __device__ helper functions to deal with float -> quant datatype conversion + */ + +#include "quantization/vectorization.cuh" +// TODO(luka/varun):refactor common.cuh to use this file instead +#include "quantization/fp8/common.cuh" + +namespace vllm { + +// TODO(luka/varun): combine into common utilities for int8 +// (with int8_quant_kernels.cu) +static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { +#ifdef USE_ROCM + static const float i8_min = + static_cast(std::numeric_limits::min()); + static const float i8_max = + static_cast(std::numeric_limits::max()); + // round + float dst = std::nearbyint(x); + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) { + float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); + return static_cast(r); +} + +template +struct ScaledQuant; + +template +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_int8_rn(x * scale); + } else { + return float_to_int8_rn(x / scale); + } + } +}; + +template +struct ScaledQuant< + quant_type_t, is_scale_inverted, + typename std::enable_if_t>> { + static __device__ __forceinline__ quant_type_t quant_fn(float const x, + float const scale) { + if constexpr (is_scale_inverted) { + return float_to_fp8(x * scale); + } else { + return float_to_fp8(x / scale); + } + } +}; + +template +__device__ void 
scaled_quant_conversion(quant_type_t* __restrict__ output, + scalar_t const* __restrict__ input, + float const scale, int const tid, + int const num_elements, + int const step) { + for (int i = tid; i < num_elements; i += step) { + output[i] = ScaledQuant(input[i], scale); + } +} + +} // namespace vllm diff --git a/csrc/quantization/vectorization.cuh b/csrc/quantization/vectorization.cuh new file mode 100644 index 00000000..44c99913 --- /dev/null +++ b/csrc/quantization/vectorization.cuh @@ -0,0 +1,33 @@ +#pragma once +/** + * __device__ datatypes vectorized by 4 + */ + +// Include both AMD and NVIDIA fp8 types to avoid circular import +// TODO(luka/varun) use FP8_TYPE instead after refactoring +#include +#include + +namespace vllm { + +// Vectorization containers +template +struct __align__(8) vec4_t { + scalar_t x; + scalar_t y; + scalar_t z; + scalar_t w; +}; + +template +struct __align__(4) q8x4_t { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + quant_type_t x; + quant_type_t y; + quant_type_t z; + quant_type_t w; +}; + +} // namespace vllm diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 4e64b9c9..1ffab148 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -128,6 +128,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA, &fused_add_rms_norm_static_fp8_quant); + // Fused Layernorm + Quant kernels + ops.def( + "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, " + "Tensor weight, Tensor! scale, float epsilon, " + "Tensor? scale_ub, Tensor!? residual) -> ()"); + ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA, + &rms_norm_dynamic_per_token_quant); + // Rotary embedding // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 50361890..ea3aaee9 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -4,10 +4,10 @@ import torch import vllm.envs as envs from vllm import LLM, SamplingParams from vllm.compilation.fix_functionalization import FixFunctionalizationPass -from vllm.compilation.fusion import (FusionPass, find_auto_fn, - find_auto_fn_maybe) +from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, + kFp8DynamicTokenSym, kFp8StaticTensorSym) +from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.reshapes import RedundantReshapesPass -from vllm.compilation.vllm_inductor_pass import is_func from vllm.config import CompilationConfig from .backend import TestBackend @@ -35,12 +35,16 @@ prompts = [ ] -@pytest.mark.parametrize("model", - ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"]) +@pytest.mark.parametrize( + "model, quant_key", + [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym), + ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e", + kFp8DynamicTokenSym)]) @pytest.mark.parametrize("do_fusion", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fix_functionalization(model: str, do_fusion: bool): +def test_fix_functionalization(model: str, quant_key: QuantKey, + do_fusion: bool): torch.set_default_device("cuda") config = CompilationConfig.PassConfig(enable_fusion=do_fusion, @@ -78,8 +82,9 @@ def test_fix_functionalization(model: str, do_fusion: bool): # OPS_IN_MODEL always appear. 
RMS_OP is fused away if we run fusion, # and replaced by fused quantized ops in RMS_QUANT_OPS. - ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"] - if do_fusion else [RMS_OP]) + rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] + ] if do_fusion else [RMS_OP] + ops = OPS_IN_MODEL + rms_ops for op in ops: find_auto_fn(backend_no_func.graph_post_pass.nodes, op) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index f92ec8d0..b4266a4a 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -3,8 +3,9 @@ import torch from compressed_tensors.quantization import FP8_DTYPE import vllm.envs as envs -from vllm.compilation.fusion import (FusionPass, find_auto_fn, - find_auto_fn_maybe) +from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, + FusionPass, QuantKey) +from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.reshapes import RedundantReshapesPass from vllm.config import CompilationConfig from vllm.model_executor.layers.layernorm import RMSNorm @@ -16,24 +17,37 @@ from .backend import TestBackend class TestModel(torch.nn.Module): - def __init__(self, hidden_size: int, eps: float, *args, **kwargs): + def __init__(self, hidden_size: int, eps: float, static: bool, *args, + **kwargs): super().__init__(*args, **kwargs) self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)] - self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)] + self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + if static: + self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)] + else: + self.scale = [None for _ in range(2)] self.w = [ torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t() for _ in range(2) ] def forward(self, x): - resid = torch.relu(x) + resid = torch.sqrt(x) y = self.norm[0](x) - x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1]) + x2 = apply_fp8_linear(y, + self.w[0], + self.wscale[0], + self.scale[0], + use_per_token_if_dynamic=True) # make sure resid is used for replacement to work y2, resid = self.norm[1](x2, resid) - x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3]) + x3 = apply_fp8_linear(y2, + self.w[1], + self.wscale[1], + self.scale[1], + use_per_token_if_dynamic=True) y3, resid = self.norm[2](x3, resid) # use resid here return y3 @@ -42,14 +56,13 @@ class TestModel(torch.nn.Module): @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @pytest.mark.parametrize("eps", [1e-5, 1e-6]) +@pytest.mark.parametrize("static", [True, False]) @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", reason="Only test on CUDA") -def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): +def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static): torch.set_default_device("cuda") - torch.set_default_dtype(torch.float16) - - if eps != 1e-5: - pytest.skip("Only test eps=1e-5 for now") + torch.set_default_dtype(dtype) + torch.manual_seed(1) # Reshape pass is needed for the fusion pass to work config = CompilationConfig.PassConfig(enable_fusion=True, @@ -58,7 +71,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): fusion_pass = FusionPass.instance(config) backend = TestBackend(reshape_pass, fusion_pass) - model = TestModel(hidden_size, eps) + model = TestModel(hidden_size, eps, static) # First dimension dynamic x = torch.rand(num_tokens, hidden_size) @@ -69,16 +82,28 @@ def 
test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): model2 = torch.compile(model, backend=backend) result2 = model2(x) - # Check that it gives the same answer - torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3) + # Higher tol for dynamic, even higher for bfloat16 + if static: + ATOL, RTOL = (1e-3, 1e-3) + elif dtype == torch.float16: + ATOL, RTOL = (2e-3, 2e-3) + else: + ATOL, RTOL = (1e-2, 1e-2) + + torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL) # Check substitution worked pre_nodes = backend.graph_pre_pass.nodes post_nodes = backend.graph_post_pass.nodes - rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default - add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + # static is per-tensor, dynamic is per-token + key = QuantKey(dtype=FP8_DTYPE, + static=static, + per_tensor=static, + symmetric=True) + rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)] + add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)] + fp8_quant = QUANT_OPS[key] # In pre-nodes, fp8 quant should be present and fused kernels should not assert find_auto_fn_maybe(pre_nodes, rms_quant) is None diff --git a/tests/kernels/test_fused_quant_layernorm.py b/tests/kernels/test_fused_quant_layernorm.py new file mode 100644 index 00000000..baf8d73f --- /dev/null +++ b/tests/kernels/test_fused_quant_layernorm.py @@ -0,0 +1,171 @@ +from typing import Optional, Tuple, Union + +import pytest +import torch + +import vllm._custom_ops as ops +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.layernorm import RMSNorm + +DTYPES = [torch.bfloat16, torch.float] +QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn] +VEC_HIDDEN_SIZES = range(1024, 1030) +# Avoid combinatorial explosion with full Cartesian product +NUM_TOKENS_HIDDEN_SIZES = [ + *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]], + *[(83, i) for i in [1, 1033, 2048, 5120]], + *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]], + *[(4096, i) for i in [1, 64, 5137]], +] + +ADD_RESIDUAL = [False, True] +SCALE_UBS = [True, False] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +EPS = 1e-6 + +## Helpers + + +def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor: + return torch.as_tensor(x, dtype=torch.float32, device='cuda') + + +def ref_rms_norm(rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if residual is not None: + residual = residual.clone() + out, residual = rms_norm_layer.forward_native(x, residual) + else: + out = rms_norm_layer.forward_native(x) + + return out, residual + + +def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if scale_ub is not None: + assert quant_dtype == torch.float8_e4m3fn + + # Norm + torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual) + + # Quant + if quant_dtype == torch.float8_e4m3fn: + torch_out, scales = ops.scaled_fp8_quant(torch_out, + scale_ub=scale_ub, + use_per_token_if_dynamic=True) + else: + assert quant_dtype == torch.int8 + torch_out, scales = ops.scaled_int8_quant(torch_out) + + return torch_out, scales, residual + + +def ref_impl(rms_norm_layer: RMSNorm, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: 
Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype, + residual, scale_ub) + + +def ops_dynamic_per_token_quant(weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if residual is not None: + residual = residual.clone() + out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS, + quant_dtype, scale_ub, + residual) + return out, scales, residual + + +def ops_impl(weight: torch.Tensor, + x: torch.Tensor, + quant_dtype: torch.dtype, + residual: Optional[torch.Tensor], + scale_ub: Optional[torch.Tensor]) \ + -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual, + scale_ub) + + +@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES) +@pytest.mark.parametrize("add_residual", ADD_RESIDUAL) +@pytest.mark.parametrize("scale_ub", SCALE_UBS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_rms_norm( + num_tokens: int, + hidden_size: int, + add_residual: bool, + scale_ub: bool, + dtype: torch.dtype, + quant_dtype: torch.dtype, + seed: int, + device: str, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + if scale_ub is not None and quant_dtype != torch.float8_e4m3fn: + # skip + return + + layer = RMSNorm(hidden_size, EPS).to(dtype=dtype) + + # Make weights + layer.weight.data.normal_(mean=1.0, std=0.1) + + # Make inputs + scale = 1 / (hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale + residual = torch.randn_like(x) * scale if add_residual else None + if scale_ub is not None: + rms_x, _ = ref_rms_norm(layer, x, residual) + scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda') + + ref_out, ref_scales, ref_residual = \ + ref_impl(layer, x, quant_dtype, residual, scale_ub) + ops_out, ops_scales, ops_residual = \ + ops_impl(layer.weight, x, quant_dtype, residual, scale_ub) + + assert ref_out.dtype == quant_dtype + assert ops_out.dtype == quant_dtype + assert torch.allclose(ref_scales, ops_scales) + if quant_dtype == torch.int8: + # big atol to account for round-off errors. 
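+        # (Hedged sketch, not in the original test: an equivalent check is to
+        #  compare dequantized values, allowing one quantization step of the
+        #  largest per-token scale.)
+        # ref_deq = ref_out.to(torch.float32) * ref_scales
+        # ops_deq = ops_out.to(torch.float32) * ops_scales
+        # assert torch.allclose(ref_deq, ops_deq, atol=float(ref_scales.max()))
+        # The element-wise int8 comparison below allows a one-unit difference: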
+ assert torch.allclose(ref_out, ops_out, atol=1) + else: + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) + if add_residual: + assert torch.allclose(ref_residual, ops_residual) + + output = torch.empty_like(x, dtype=quant_dtype) + scales = torch.empty((x.numel() // x.shape[-1], 1), + device=x.device, + dtype=torch.float32) + + opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant, + (output, x, layer.weight, scales, 1e-5, scale_ub, residual)) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index c192c9a7..d6002630 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -249,6 +249,26 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, block_table_bound) +# fused quant layer norm ops +def rms_norm_dynamic_per_token_quant( + input: torch.Tensor, + weight: torch.Tensor, + epsilon: float, + quant_dtype: torch.dtype, + scale_ub: Optional[torch.Tensor] = None, + residual: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + output = torch.empty_like(input, dtype=quant_dtype) + scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + + torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight, + scales, epsilon, scale_ub, + residual) + return output, scales + + # quantization ops # awq def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py index 3584cc36..e15d7b31 100644 --- a/vllm/compilation/fix_functionalization.py +++ b/vllm/compilation/fix_functionalization.py @@ -6,7 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized from vllm.logger import init_logger -from .vllm_inductor_pass import VllmInductorPass, is_func +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) @@ -53,14 +54,16 @@ class FixFunctionalizationPass(VllmInductorPass): self.insert_defunctionalized(graph, node) self._remove(node) - # These 2 replacements avoid the most copies for LLaMa. + # rms_norm replacements avoid the most copies for LLaMa. 
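+            # Illustration (sketch): defunctionalizing an in-place op turns
+            #   at = auto_functionalized(op, result=r, residual=res, ...)
+            #   r_new, res_new = at[1], at[2]
+            # into a direct call to op that mutates r and res, with users of
+            # the getitem nodes rebound to those tensors. The mutated_args
+            # dicts below map each getitem index to the kwarg the op mutates.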
elif at_target == torch.ops._C.fused_add_rms_norm.default: mutated_args = {1: 'input', 2: 'residual'} self.defunctionalize(graph, node, mutated_args) elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 mutated_args = {1: 'result', 2: 'residual'} self.defunctionalize(graph, node, mutated_args) - + elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} + self.defunctionalize(graph, node, mutated_args) elif at_target in [ torch.ops._C.rms_norm.default, torch.ops._C.rms_norm_static_fp8_quant.default diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index 5efa410f..cde27bd1 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -1,129 +1,517 @@ -import operator -from typing import Iterable, List, Optional +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple import torch +import torch._inductor.pattern_matcher as pm +# TODO(luka) use vllm.utils once #10836 landed +from compressed_tensors.quantization import FP8_DTYPE +from torch import fx from torch._higher_order_ops.auto_functionalize import auto_functionalized -from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, - fwd_only, register_replacement) +from torch._inductor.pattern_matcher import PatternMatcherPass +from torch._ops import OpOverload from vllm.config import CompilationConfig from vllm.logger import init_logger -from .vllm_inductor_pass import VllmInductorPass, is_func +from .fx_utils import find_getitem_maybe +from .multi_output_match import MultiOutputMatch +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) -def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(torch.ops._C.rms_norm.default, - result=result_rms, - input=input, - weight=weight, - epsilon=1e-5) - at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at1[1], - scale=scale) - - # result - return at2[1] - - -def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor, - input: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default, - result=result, - input=input, - weight=weight, - scale=scale, - epsilon=1e-5) - - # result - return at[1] - - -def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, weight: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default, - input=input, - residual=residual, - weight=weight, - epsilon=1e-5) - at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default, - result=result, - input=at[1], - scale=scale) - - # result, residual - return at1[1], at[2] - - -def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor, - residual: torch.Tensor, - weight: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized( - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - result=result, - input=input, - residual=residual, - weight=weight, - scale=scale, - epsilon=1e-5) - # result, residual - return at[1], at[2] - - def empty_bf16(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") -def empty_fp8(*args, **kwargs): - fp8 = torch.float8_e4m3fn - return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") - 
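+# Usage sketch (mirrors tests/compile/test_fusion.py; TestBackend and
+# RedundantReshapesPass come from the test and reshape-pass modules, not from
+# this file):
+#
+#   config = CompilationConfig.PassConfig(enable_fusion=True,
+#                                         enable_reshape=True)
+#   fusion_pass = FusionPass.instance(config)
+#   backend = TestBackend(RedundantReshapesPass(config), fusion_pass)
+#   fused_model = torch.compile(model, backend=backend)
+#
+# The reshape pass is needed for the fusion pass to work (see the test).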
- def empty_fp32(*args, **kwargs): return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") -# Utilities for post-processing multi-output matches +RMS_OP = torch.ops._C.rms_norm.default +RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default -# Returns the first auto_functionalized node with the given op (if it exists) -def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node], - op) -> Optional[torch.fx.Node]: - for node in nodes: - if is_func(node, auto_functionalized) and node.args[0] == op: # noqa - return node - return None +class QuantKey(NamedTuple): + """ + Named tuple for identifying the type of quantization. + dtype: quantized data type + static: static quantization if True, dynamic if False + per_tensor: per-tensor quantization if True, per-token if False + symmetric: symmetric if True, asymmetric if False + """ + dtype: torch.dtype + static: bool + per_tensor: bool = True + symmetric: bool = True + + def __str__(self): + return (f"QuantKey({'static' if self.static else 'dynamic'}," + f"{fx.graph.dtype_abbrs[self.dtype]}," + f"{'per_tensor' if self.per_tensor else 'per_token'}," + f"{'a' if not self.symmetric else ''}symmetric)") -# Returns the first auto_functionalized node with the given op -def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node: - node = find_auto_fn_maybe(nodes, op) - assert node is not None, f"Could not find {op} in nodes {nodes}" - return node +kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True) +kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True) +kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True) + +QUANT_OPS: Dict[QuantKey, OpOverload] = { + kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default, # noqa + kFp8DynamicTensorSym: + torch.ops._C.dynamic_scaled_fp8_quant.default, # noqa + kFp8DynamicTokenSym: + torch.ops._C.dynamic_per_token_scaled_fp8_quant.default, # noqa +} -# Returns the getitem node that extracts the idx-th element from node -# (if it exists) -def find_getitem_maybe(node: torch.fx.Node, - idx: int) -> Optional[torch.fx.Node]: - for user in node.users: - if is_func(user, operator.getitem) and user.args[1] == idx: - return user - return None +class FusedRMSQuantKey(NamedTuple): + """ + Named tuple for identifying the type of RMSNorm + quant fusion. 
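+    For example, FusedRMSQuantKey(kFp8DynamicTokenSym, fused_add=True) selects
+    torch.ops._C.rms_norm_dynamic_per_token_quant in FUSED_OPS below.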
+ quant: type of quantization + fused_add: does the op also perform the residual add + """ + quant: QuantKey + fused_add: bool + + def __str__(self): + return (f"FusedQuantKey({self.quant}, with" + f"{'' if self.fused_add else 'out'} residual)") -# Returns the getitem node that extracts the idx-th element from node -def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: - ret = find_getitem_maybe(node, idx) - assert ret is not None, f"Could not find getitem {idx} in node {node}" - return ret +FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = { + FusedRMSQuantKey(kFp8StaticTensorSym, False): + torch.ops._C.rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8StaticTensorSym, True): + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, False): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa + FusedRMSQuantKey(kFp8DynamicTokenSym, True): + torch.ops._C.rms_norm_dynamic_per_token_quant.default, # noqa +} + + +class QuantMultiOutputMatch(MultiOutputMatch): + + def __init__(self, match: pm.Match, quant_op, fused_op): + super().__init__(match) + assert isinstance(quant_op, OpOverload) + assert isinstance(fused_op, OpOverload) + self.QUANT_OP = quant_op # in-place quant op + self.FUSED_OP = fused_op # in-place fused quant op + + def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node, + int]], + **kwargs): + """ + This utility function inserts an auto-functionalized node for FUSED_OP. + It also correctly sets its meta value and rebinds the users of the + unfused nodes to use the fused node instead. + + :param fused_return_mapping: A dictionary, mapping from getitem indices + of the fused node result to a tuple of the old node and a getitem index. + :param kwargs: kwargs that get directly forwarded to the auto_fn node + + Example: + If we want to replace this graph: + _, x1, x2 = auto_fn(op1) + _, y1, y2 = auto_fn(op2) + + with + _, x1, y2, x2 = auto_fn(FUSED_OP) + + we would call: + insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)} + + Note that the 0th element is None for auto-functionalized in-place ops. + Hence, others appear 1-indexed. + """ + fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs) + indices = fused_return_mapping.keys() + getitem_nodes = self.insert_getitems(fused_node, indices) + + # Prepare the meta value, use a list so it's mutable + meta_val = [None] * (max(indices) + 1) + + # Iterate through elements of the tuple produced by fused_node + for idx, getitem_node in zip(indices, getitem_nodes): + old_node, old_idx = fused_return_mapping[idx] + + # If the old value was never used, the old_getitem might not exist + old_getitem = find_getitem_maybe(old_node, old_idx) + if old_getitem is not None: + # Rebind the users of match getitem nodes to use the new nodes. + # The old nodes will be removed by DCE at the end of the pass. 
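+                # Using the docstring example above, for entry {1: (op1_node, 1)}
+                # every user of the old getitem op1_node[1] is redirected to read
+                # element 1 of the new fused node instead.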
+ old_getitem.replace_all_uses_with(getitem_node) + getitem_node.meta["val"] = old_getitem.meta["val"] + + # Extract the appropriate meta value + # It is present even if the getitem node does not exist + meta_val[idx] = old_node.meta["val"][old_idx] + + # Fix the meta value on the new fused node + fused_node.meta["val"] = tuple(meta_val) + + +class RMSNormQuantPattern: + + def __init__(self, epsilon: float, key: FusedRMSQuantKey): + self.epsilon = epsilon + self.quant_dtype = key.quant.dtype + + assert key.quant in QUANT_OPS, \ + f"unsupported quantization scheme {key.quant}" + self.QUANT_OP = QUANT_OPS[key.quant] + + assert key in FUSED_OPS, \ + f"unsupported fused rmsnorm+quant op for {key}" + self.FUSED_OP = FUSED_OPS[key] + + +class RMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + fused_key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, fused_key) + + def register(self, pm_pass: PatternMatcherPass): + # Cannot use methods, as the self argument affects tracing + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale) + + # result + return at2[1] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result + return at[1] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, + pm_pass) + + +class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=True, + per_tensor=True, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale) + + # result, residual + return at1[1], at[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + residual=residual, + weight=weight, + scale=scale, + epsilon=self.epsilon) + + # result, residual + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + 
pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 1 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and residual. + # The auto_fn node returns a tuple of (None, result, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa + # result_node_new = at[1] + # residual_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + + # 0 is always None + fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)} + self.insert_fused_node(fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + **kwargs) + + +class RMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool, + symmetric=True): + key = FusedRMSQuantKey(fused_add=False, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at1 = auto_functionalized(RMS_OP, + result=result_rms, + input=input, + weight=weight, + epsilon=self.epsilon) + at2 = auto_functionalized(self.QUANT_OP, + result=result, + input=at1[1], + scale=scale, + scale_ub=None) + + # result, scale + return at2[1], at2[2] + + def replacement(result: torch.Tensor, result_rms: torch.Tensor, + input: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=None) + + # result, scale + return at[1], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # result_rms + empty_bf16(5, 4), # input + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 1 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract the result and scale. + # The auto_fn node returns a tuple of (None, result, scale). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) 
# noqa + # result_node_new = at[1] + # scale_node_new = at[2] + with self.inserting_after_match(): + # Missing epsilon, scalars cannot be inputs to the pattern + kwargs = self.match.kwargs.copy() + del kwargs["result_rms"] # not used in the fused op + + fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)} + self.insert_fused_node( + fused_return_mapping, + epsilon=rms_node.kwargs["epsilon"], + scale_ub=None, # not used but required + residual=None, # not used but required + **kwargs) + + +class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern): + + def __init__(self, + epsilon: float, + quant_dtype: torch.dtype, + per_tensor: bool = True, + symmetric=True): + key = FusedRMSQuantKey(fused_add=True, + quant=QuantKey(dtype=quant_dtype, + static=False, + per_tensor=per_tensor, + symmetric=symmetric)) + super().__init__(epsilon, key) + + def register(self, pm_pass: PatternMatcherPass, + record_match: Callable[[MultiOutputMatch], bool]): + + def pattern(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(RMS_ADD_OP, + input=input, + residual=residual, + weight=weight, + epsilon=self.epsilon) + at1 = auto_functionalized(self.QUANT_OP, + result=result, + input=at[1], + scale=scale, + scale_ub=None) + + # result, residual, scale + return at1[1], at[2], at1[2] + + def replacement(result: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor): + at = auto_functionalized(self.FUSED_OP, + result=result, + input=input, + weight=weight, + scale=scale, + epsilon=self.epsilon, + scale_ub=None, + residual=residual) + + # result, residual, scale + return at[1], at[3], at[2] + + inputs = [ + torch.empty(5, 4, device="cuda", dtype=self.quant_dtype), # result + empty_bf16(5, 4), # input + empty_bf16(5, 4), # residual + empty_bf16(1, 5), # weight + empty_fp32(1, 1) # scale + ] + + pm.register_replacement( + pattern, + replacement, + inputs, + pm.fwd_only, + pm_pass, + extra_check=lambda m: record_match( + self.Match(m, self.QUANT_OP, self.FUSED_OP))) + + class Match(QuantMultiOutputMatch): + + def process(self): + # Find the nodes in the match that we need to rebind + rms_node = self.find_auto_fn(RMS_ADD_OP) + quant_node = self.find_auto_fn(self.QUANT_OP) + + assert len(rms_node.users) == 2 + assert len(quant_node.users) == 2 + + # First, insert a new auto_functionalized node for the fused op, + # as well as getitem nodes to extract result, scale, and residual. + # The auto_fn node returns a tuple (None, result, scale, residual). + # + # The resulting graph looks like this: + # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...) 
# noqa
+            #   result_node_new = at[1]
+            #   scale_node_new = at[2]
+            #   residual_node_new = at[3]
+            with self.inserting_after_match():
+                # Missing epsilon, scalars cannot be inputs to the pattern
+                kwargs = self.match.kwargs.copy()
+
+                fused_return_mapping = {
+                    1: (quant_node, 1),  # result
+                    2: (quant_node, 2),  # scale
+                    3: (rms_node, 2),  # residual
+                }
+                self.insert_fused_node(
+                    fused_return_mapping,
+                    epsilon=rms_node.kwargs["epsilon"],
+                    scale_ub=None,  # not used but required
+                    **kwargs)
 
 
 class FusionPass(VllmInductorPass):
@@ -158,41 +546,39 @@ class FusionPass(VllmInductorPass):
             "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: List[Match] = []
+        self.matches: List[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="fusion_pass")
 
-        # Fuse rms_norm + static_scaled_fp8_quant into
-        # rms_norm_static_fp8_quant
-        inputs = [
-            empty_fp8(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(1, 5),
-            empty_fp32(1, 1)
-        ]
-        register_replacement(rms_pattern_static, rms_replacement_static,
-                             inputs, fwd_only, self.patterns)
+        for epsilon in [1e-5, 1e-6]:
+            # Fuse rms_norm + static fp8 quant
+            RMSNormStaticQuantPattern(epsilon,
+                                      FP8_DTYPE).register(self.patterns)
 
-        # Fuse fused_add_rms_norm + static_scaled_fp8_quant into
-        # fused_add_rms_norm_static_fp8_quant
-        # Because pattern has 2 outputs, we need to manually process the match
-        # (see process_matches)
-        inputs = [
-            empty_fp8(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(1, 5),
-            empty_fp32(1, 1)
-        ]
-        register_replacement(rms_pattern_residual_static,
-                             rms_replacement_residual_static,
-                             inputs,
-                             fwd_only,
-                             self.patterns,
-                             extra_check=lambda m: self.record_match(m))
+            # Matches for patterns below have 2 or more outputs,
+            # so we need to process them manually (see process_matches)
 
-    def record_match(self, match: Match) -> bool:
+            # Fuse fused_add_rms_norm + static fp8 quant
+            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
+                self.patterns, self.record_match)
+
+            # Fuse rms_norm + dynamic per-token fp8 quant
+            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
+                                       per_tensor=False).register(
+                                           self.patterns, self.record_match)
+
+            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
+            FusedAddRMSNormDynamicQuantPattern(epsilon,
+                                               FP8_DTYPE,
+                                               per_tensor=False).register(
+                                                   self.patterns,
+                                                   self.record_match)
+
+        # WARNING: This is a hack to clear the pattern matcher cache
+        # and allow multiple values of epsilon.
+        torch._inductor.pattern_matcher._seen_patterns.clear()
+
+    def record_match(self, match: MultiOutputMatch) -> bool:
        # Hijack the extra_check to record the match and
        # save it for post-processing.
        self.matches.append(match)
@@ -200,83 +586,20 @@ class FusionPass(VllmInductorPass):
         # Return False to prevent automatic replacement.
         return False
 
-    def process_matches(self, graph: torch.fx.Graph):
+    def process_matches(self, graph: fx.Graph):
         """
         Manually process multi-output matches and replace them with fused
         nodes.
-        This is necessary because the automatic replacement for multi-output
-        matches is broken: https://github.com/pytorch/pytorch/issues/137280
+        See MultiOutputMatch for more details.
         """
         for match in self.matches:
-            # To avoid use-before-definition errors, insert replacement nodes
-            # after the last node in the match.
-            # match.nodes is not guaranteed to be sorted.
-            # Find the last node in the match.
- for last_node_in_match in reversed(graph.nodes): - if last_node_in_match in match.nodes: - break - else: - raise ValueError("No nodes in graph") - - # Insert a new auto_functionalized node for the fused operation, - # as well as getitem nodes to extract the result and residual. - # The auto_functionalized node returns a tuple of - # (None, result, residual) - None is the function return value. - # The resulting graph looks like this: - # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...) # noqa - # result_node_new = at[1] - # residual_node_new = at[2] - with graph.inserting_after(last_node_in_match): - kwargs = match.kwargs - kwargs["epsilon"] = 1e-5 # Currently hard-coded in RMSNorm - - fused_node = graph.call_function( - auto_functionalized, - (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, - ), - kwargs=kwargs) - - graph.inserting_after(fused_node) - result_node_new = graph.call_function(operator.getitem, - (fused_node, 1)) - residual_node_new = graph.call_function( - operator.getitem, (fused_node, 2)) - - # Last part of replacement is rebinding the users of nodes in the - # match to use the new nodes. - - # Find the nodes in the match that we need to rebind - rms_node = find_auto_fn(match.nodes, - torch.ops._C.fused_add_rms_norm.default) - quant_node = find_auto_fn( - match.nodes, torch.ops._C.static_scaled_fp8_quant.default) - - assert len(rms_node.users) == 2 - assert len(quant_node.users) == 1 - - # meta["val"] is used by de-functionalization and has to contain the - # value of the node (tuple of tensors) that would be returned by the - # functionalized node during tracing. - - rms_tup = rms_node.meta["val"] - quant_tup = quant_node.meta["val"] - - # The result of fused_node must be a tuple with the first element - # None (the function return value) and the remaining elements - # representing the mutated inputs. - fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2]) - fused_node.meta["val"] = fused_tup - - # Find the getitem nodes and replace their uses with the new nodes. - # The old nodes will be removed by DCE at the end of the pass. 
- find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new) - find_getitem(quant_node, 1).replace_all_uses_with(result_node_new) + match.process() # Finally, remove matched nodes graph.eliminate_dead_code() assert all(node not in graph.nodes for match in self.matches - for node in match.nodes) + for node in match.match.nodes) - def __call__(self, graph: torch.fx.Graph): + def __call__(self, graph: fx.Graph): self.begin() self.dump_graph(graph, "before_fusion") diff --git a/vllm/compilation/fx_utils.py b/vllm/compilation/fx_utils.py new file mode 100644 index 00000000..924e26f2 --- /dev/null +++ b/vllm/compilation/fx_utils.py @@ -0,0 +1,42 @@ +import operator +from typing import Iterable, Optional + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._ops import OpOverload + + +def is_func(node: fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +# Returns the first auto_functionalized node with the given op (if it exists) +def find_auto_fn_maybe(nodes: Iterable[fx.Node], + op: OpOverload) -> Optional[fx.Node]: + for node in nodes: + if is_func(node, auto_functionalized) and node.args[0] == op: # noqa + return node + return None + + +# Returns the first auto_functionalized node with the given op +def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node: + node = find_auto_fn_maybe(nodes, op) + assert node is not None, f"Could not find {op} in nodes {nodes}" + return node + + +# Returns the getitem node that extracts the idx-th element from node +# (if it exists) +def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]: + for user in node.users: + if is_func(user, operator.getitem) and user.args[1] == idx: + return user + return None + + +# Returns the getitem node that extracts the idx-th element from node +def find_getitem(node: fx.Node, idx: int) -> fx.Node: + ret = find_getitem_maybe(node, idx) + assert ret is not None, f"Could not find getitem {idx} in node {node}" + return ret diff --git a/vllm/compilation/multi_output_match.py b/vllm/compilation/multi_output_match.py new file mode 100644 index 00000000..0ad648ab --- /dev/null +++ b/vllm/compilation/multi_output_match.py @@ -0,0 +1,105 @@ +import abc +import operator +from abc import abstractmethod +from typing import Iterable, List, Tuple + +from torch import fx +from torch._higher_order_ops.auto_functionalize import auto_functionalized +from torch._inductor import pattern_matcher as pm +from torch._ops import OpOverload + +from vllm.compilation.fx_utils import find_auto_fn + + +class MultiOutputMatch(abc.ABC): + """ + This class provides utilities to process multi-output matches and + manually insert replacements. + + This is necessary because the automatic replacement for multi-output + matches is broken: https://github.com/pytorch/pytorch/issues/137280 + """ + + def __init__(self, match: pm.Match): + self.match = match + + @abstractmethod + def process(self): + """ + Process a multi-output match and manually insert the replacement. + + This method should: + 1. Insert the replacement nodes after the last node in the match. + 2. Rebind the users of nodes in the match to use the new nodes. + 3. Set meta["val"] for de-functionalization. + + The result of an auto-functionalized node is a tuple of tensors. + The first element is the return value of the function, usually None. + The remaining elements are the mutated args of the function. 
+ + All auto-functionalized nodes must contain a proper meta["val"], + as it is used by de-functionalization. meta["val"] has to contain the + value of the node (tuple of tensors) that would be returned by the + functionalized node during tracing. + + Existing nodes in the graph all have this property set, but we have + to set it manually for new nodes we insert. + + Example: + # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None + at = auto_functionalized(torch.ops._C.foo.default, a, b, c) + # at.meta["val"] = (None, a, c) + """ + raise NotImplementedError + + @property + def nodes(self) -> List[fx.Node]: + return self.match.nodes + + @property + def graph(self) -> fx.Graph: + return self.match.graph + + def find_auto_fn(self, op) -> fx.Node: + """ + Find the first auto_functionalized node with the given op in the match. + """ + return find_auto_fn(self.nodes, op) + + def inserting_after_match(self): + """ + Insert nodes after the last node in the match. + This is done to avoid use-before-definition errors after inserting + replacement nodes. + """ + + # match.nodes is not guaranteed to be sorted. + # Find the last node in the match. + for last_node_in_match in reversed(self.graph.nodes): + if last_node_in_match in self.match.nodes: + break + else: + raise ValueError("No nodes in graph") + + return self.graph.inserting_after(last_node_in_match) + + def insert_getitems(self, tuple_node: fx.Node, + indices: Iterable[int]) -> Tuple[fx.Node, ...]: + """ + Insert operator.getitem nodes to extract elements from a tuple node. + + :param tuple_node: The tuple node to extract elements from. + :param indices: The indices of the elements to extract. + :return: Tuple of the new getitem nodes, corresponding to the indices. + """ + with self.graph.inserting_after(tuple_node): + return tuple( + self.graph.call_function(operator.getitem, (tuple_node, idx)) + for idx in indices) + + def insert_auto_fn(self, op: OpOverload, kwargs): + """ + Insert an auto_functionalized node with the given op and kwargs. + """ + return self.graph.call_function(auto_functionalized, (op, ), + kwargs=kwargs) diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py index 63a369fe..ba28b1f0 100644 --- a/vllm/compilation/reshapes.py +++ b/vllm/compilation/reshapes.py @@ -5,7 +5,8 @@ from torch import SymInt from vllm.logger import init_logger -from .vllm_inductor_pass import VllmInductorPass, is_func +from .fx_utils import is_func +from .vllm_inductor_pass import VllmInductorPass logger = init_logger(__name__) diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py index dbf6b8f7..b8c52a7f 100644 --- a/vllm/compilation/vllm_inductor_pass.py +++ b/vllm/compilation/vllm_inductor_pass.py @@ -16,10 +16,6 @@ from .inductor_pass import InductorPass logger = init_logger(__name__) -def is_func(node: torch.fx.Node, target) -> bool: - return node.op == "call_function" and node.target == target - - class VllmInductorPass(InductorPass): """ An inductor pass with access to vLLM PassConfig.
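
As a minimal standalone sketch of the getitem bookkeeping that fx_utils.py and
MultiOutputMatch.insert_getitems rely on (illustrative only, not part of the
patch; torch.max stands in for the tuple-producing auto_functionalized node,
and the helper mirrors find_getitem_maybe from vllm/compilation/fx_utils.py):

    import operator

    import torch
    from torch import fx


    def toy(x):
        # torch.max(dim=...) returns a (values, indices) tuple; indexing the
        # traced proxy produces operator.getitem nodes, just like the tuple
        # returned by an auto_functionalized node in the fusion pass.
        out = torch.max(x, dim=0)
        return out[0] + 1, out[1]


    gm = fx.symbolic_trace(toy)


    def find_getitem_maybe(node: fx.Node, idx: int):
        # Scan the users of a tuple-producing node for the getitem that
        # extracts element idx, returning None if no such user exists.
        for user in node.users:
            if (user.op == "call_function"
                    and user.target == operator.getitem
                    and user.args[1] == idx):
                return user
        return None


    max_node = next(n for n in gm.graph.nodes
                    if n.op == "call_function" and n.target == torch.max)
    print(find_getitem_maybe(max_node, 0))  # getitem node for out[0]
    print(find_getitem_maybe(max_node, 1))  # getitem node for out[1]

The fusion pass performs the same scan over node.users to locate (or insert)
the getitem that extracts a given tuple element before rebinding its uses to
the fused node's outputs.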