[torch.compile] Dynamic fp8 + rms_norm fusion (#10906)
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent 78ed8f57d8
commit 30870b4f66
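For orientation, a minimal sketch of the unfused path this commit fuses versus the new fused op. This is illustrative only; the authoritative call sites are in the benchmark and tests added below, and the shapes and dtypes here are made up:

    import torch
    import vllm._custom_ops as ops
    from vllm.model_executor.layers.layernorm import RMSNorm

    layer = RMSNorm(4096, 1e-6).to(dtype=torch.bfloat16, device="cuda")
    x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")

    # Unfused: RMS norm kernel, then a separate dynamic per-token fp8 quant kernel.
    y = layer.forward_cuda(x)
    y_q, scales = ops.scaled_fp8_quant(y, use_per_token_if_dynamic=True)

    # Fused op added by this commit: norm + dynamic per-token quant in one kernel.
    y_q2, scales2 = ops.rms_norm_dynamic_per_token_quant(x, layer.weight, 1e-6,
                                                         torch.float8_e4m3fn)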
@@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
+  "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"
benchmarks/fused_kernels/layernorm_rms_benchmarks.py (new file, 173 lines)
@@ -0,0 +1,173 @@
import pickle as pkl
import time
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm


@dataclass
class bench_params_t:
    num_tokens: int
    hidden_size: int
    add_residual: bool
    dtype: torch.dtype

    def description(self):
        return (f'N {self.num_tokens} '
                f'x D {self.hidden_size} '
                f'x R {self.add_residual} '
                f'x DT {self.dtype}')


def get_bench_params() -> List[bench_params_t]:
    ## Test Fixtures
    NUM_TOKENS = [2**x for x in range(11)]
    HIDDEN_SIZES = list(range(1024, 8129, 1024))
    ADD_RESIDUAL = [True, False]
    DTYPES = [torch.bfloat16, torch.float]

    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
    bench_params = list(map(lambda x: \
        bench_params_t(x[0], x[1], x[2], x[3]), combinations))
    return bench_params


# Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
                      residual: Optional[torch.Tensor],
                      quant_dtype: torch.dtype):
    # Norm
    torch_out = None
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _, _ = ops.scaled_int8_quant(torch_out)


def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
                     residual: Optional[torch.Tensor],
                     quant_dtype: torch.dtype):
    # Norm
    torch_out = None
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _ = ops.scaled_fp8_quant(torch_out)


def fused_impl(
        rms_norm_layer: RMSNorm,  # this stores the weights
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
        quant_dtype: torch.dtype):
    out, _ = ops.rms_norm_dynamic_per_token_quant(x,
                                                  rms_norm_layer.weight,
                                                  1e-6,
                                                  quant_dtype,
                                                  residual=residual)


# Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
             quant_dtype: torch.dtype, label: str, sub_label: str,
             fn: Callable, description: str) -> TMeasurement:

    min_run_time = 1

    globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
        globals=globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)


def bench(params: bench_params_t, label: str, sub_label: str) \
        -> Iterable[TMeasurement]:

    # Make inputs
    layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)
    # Make inputs
    scale = 1 / params.hidden_size
    x = torch.randn(params.num_tokens,
                    params.hidden_size,
                    dtype=params.dtype,
                    device='cuda') * scale
    residual = (torch.randn_like(x) * scale).to(device='cuda') \
        if params.add_residual else None

    timers = []

    # unfused int8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.int8, label, sub_label,
                 unfused_int8_impl, "unfused_int8_impl"))

    # unfused fp8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
                 unfused_fp8_impl, "unfused_fp8_impl"))

    # fused int8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
                 "fused_int8_impl"))

    # fused fp8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
                 fused_impl, "fused_fp8_impl"))

    print_timers(timers)

    return timers


# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def main():
    torch.set_default_device('cuda')
    bench_params = get_bench_params()

    timers = []
    for bp in tqdm(bench_params):
        timers.extend(
            bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
    print_timers(timers)

    # pickle all the results
    timestamp = int(time.time())
    with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
        pkl.dump(timers, f)


if __name__ == '__main__':
    main()
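The benchmark has a main entry point, so it can presumably be run directly (python benchmarks/fused_kernels/layernorm_rms_benchmarks.py); it sweeps token counts, hidden sizes, residual on/off and dtypes over the unfused and fused int8/fp8 paths, prints a comparison table, and pickles the raw timings.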
@@ -14,6 +14,20 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

+// TODO(luka/varun): use FP8_TYPE macro after refactoring
+#ifndef USE_ROCM
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                    \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#else
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#endif
+
+#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
+
 #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
@@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
                                          torch::Tensor& weight,
                                          torch::Tensor& scale, double epsilon);

+void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
+                                      torch::Tensor const& input,
+                                      torch::Tensor const& weight,
+                                      torch::Tensor& scales,
+                                      double const epsilon,
+                                      std::optional<torch::Tensor> scale_ub,
+                                      std::optional<torch::Tensor> residual);
+
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
@@ -1,6 +1,9 @@
 #pragma once

+#include "quantization/vectorization.cuh"
+
 #include <cmath>
+#include <c10/core/ScalarType.h>

 #ifndef USE_ROCM
   #include <c10/util/Float8_e4m3fn.h>
@@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
 // issue when running dynamic quantization. Here use 224.0f for rocm.
 constexpr auto FP8_E4M3_MAX = 224.0f;
 #endif
+constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;

 namespace vllm {

@@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   }
 }

-template <typename scalar_t>
-struct __align__(8) vec4_t {
-  scalar_t x;
-  scalar_t y;
-  scalar_t z;
-  scalar_t w;
-};
-
-typedef struct __align__(4) {
-  FP8_TYPE x;
-  FP8_TYPE y;
-  FP8_TYPE z;
-  FP8_TYPE w;
-} float8x4_t;
-
 template <typename scalar_t>
 __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                 int64_t const num_elems, int const tid,
@@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
+  using float8x4_t = q8x4_t<FP8_TYPE>;
   // Vectorized input/output to better utilize memory bandwidth.
-  vec4_t<scalar_t> const* vectorized_in =
-      reinterpret_cast<vec4_t<scalar_t> const*>(input);
-  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);

   int64_t const num_vec_elems = num_elems >> 2;

@@ -0,0 +1,160 @@ (new file; presumably the fused_layernorm_dynamic_per_token_quant.cu source registered in the CMake hunk above)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../dispatch_utils.h"
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"

namespace vllm {

template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void rms_norm_dynamic_per_token_quant_vec(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  float rms = 0.0f;
  float token_scale = 0.0f;

  // Compute rms
  vllm::vectorized::compute_rms<scalar_t, has_residual>(
      &rms, input, hidden_size, var_epsilon, residual);

  // Compute scale
  vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
                                                     has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // RMS Norm + Quant
  if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
    vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
                                     has_residual>(
        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
  } else {
    // FP8 - Do not invert token_scale for exact match with FBGemm
    vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
                                     has_residual>(
        out, input, weight, rms, token_scale, hidden_size, residual);
  }
}

// RMS norm + quant kernel
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__global__ void rms_norm_dynamic_per_token_quant_kernel(
    scalar_out_t* __restrict__ out,       // [..., hidden_size]
    float* __restrict__ scales,           // [num_tokens]
    scalar_t const* __restrict__ input,   // [..., hidden_size]
    scalar_t const* __restrict__ weight,  // [hidden_size]
    float const* scale_ub, float const var_epsilon,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t* __restrict__ residual = nullptr) {
  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  if (can_vectorize) {
    return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
                                                has_residual>(
        out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
        hidden_size, residual);
  }

  float rms = 0.0f;
  float token_scale = 0.0f;

  // Compute RMS
  vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
                                            var_epsilon, residual);
  // Compute Scale
  vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
      hidden_size, residual);

  // RMS Norm + Quant
  if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
    vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
        out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
  } else {
    // FP8 - Do not invert s_token_scale for exact match with FBGemm
    vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
        out, input, weight, rms, token_scale, hidden_size, residual);
  }
}
}  // namespace vllm

// Residual add + RMS norm + dynamic per token
template <typename scalar_in_t>
void rms_norm_dynamic_per_token_quant_dispatch(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,     // Variance epsilon used in norm calculation
    std::optional<at::Tensor> const& scale_ub,
    std::optional<at::Tensor>& residual) {
  int32_t hidden_size = input.size(-1);
  int32_t num_tokens = input.numel() / hidden_size;

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const float min_scaling_factor =
      out.dtype() == torch::kInt8
          ? std::numeric_limits<float>::epsilon()
          : 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);

  if (residual.has_value()) {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        true>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  var_epsilon, min_scaling_factor, hidden_size,
                  residual->data_ptr<scalar_in_t>());
        });

  } else {
    VLLM_DISPATCH_QUANT_TYPES(
        out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
          vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
                                                        false>
              <<<grid, block, 0, stream>>>(
                  out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                  input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                  scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                  var_epsilon, min_scaling_factor, hidden_size, nullptr);
        });
  }
}

void rms_norm_dynamic_per_token_quant(
    torch::Tensor& out,           // [..., hidden_size]
    torch::Tensor const& input,   // [..., hidden_size]
    torch::Tensor const& weight,  // [hidden_size]
    torch::Tensor& scales,        // [num_tokens]
    double const var_epsilon,     // Variance epsilon used in norm calculation
    std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
  TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
  TORCH_CHECK(out.is_contiguous() && input.is_contiguous());

  if (scale_ub.has_value()) {
    TORCH_CHECK(out.dtype() == kFp8Type);
  }
  TORCH_CHECK(scales.dtype() == torch::kFloat32);

  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
        rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
            out, input, weight, scales, var_epsilon, scale_ub, residual);
      });
}
csrc/quantization/fused_kernels/layernorm_utils.cuh (new file, 327 lines)
@@ -0,0 +1,327 @@
#pragma once

/**
 * __device__ layernorm utilities.
 */

#include "quantization/vectorization.cuh"
#include "quant_conversions.cuh"

#ifndef USE_ROCM
  #include <cub/cub.cuh>
#else
  #include <hipcub/hipcub.hpp>
#endif

namespace vllm {

// has_residual must be true, if residual is not a nullptr
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  // sum of squares
  float ss = 0.0f;

  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
    }

    ss += x * x;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);

  __shared__ float s_rms;
  if (threadIdx.x == 0) {
    s_rms = rsqrtf(ss / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = s_rms;
}

template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};

  float block_absmax_val_maybe = 0.0f;
  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
    }

    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
    block_absmax_val_maybe = fmaxf(block_absmax_val_maybe, fabsf(x));
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  block_absmax_val_maybe =
      BlockReduce(reduceStore)
          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);

  __shared__ float s_token_scale;
  if (threadIdx.x == 0) {
    float scale = 0.0f;
    if (scale_ub) {
      scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      scale = block_absmax_val_maybe;
    }
    // token scale computation
    scale = max(scale / qmax, min_scaling_factor);
    s_token_scale = scale;                 // Shared memory store
    all_token_scales[blockIdx.x] = scale;  // Global output store
  }
  __syncthreads();

  *token_scale = s_token_scale;
}

template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                               scalar_t const* __restrict__ input,
                               scalar_t const* __restrict__ weight,
                               float const rms, float const scale,
                               int32_t const hidden_size,
                               scalar_t* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  for (int32_t i = threadIdx.x; i < hidden_size; i += blockDim.x) {
    float x = static_cast<float>(input[token_offset + i]);
    if constexpr (has_residual) {
      x += static_cast<float>(residual[token_offset + i]);
      residual[token_offset + i] = static_cast<scalar_t>(x);
    }
    // Norm
    x = static_cast<float>(static_cast<scalar_t>(x * rms) * weight[i]);
    // Quant
    output[token_offset + i] =
        ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(x, scale);
  }
}

namespace vectorized {

// Compute 1.0/rms(input)
// hidden_size must be a multiple of 4
template <typename scalar_t, bool has_residual = false>
__device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
                            int32_t const hidden_size, float const epsilon,
                            scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual =
        reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
  }

  // sum of squares
  float ss = 0.0f;

  int32_t const num_vec_elems = hidden_size >> 2;

#pragma unroll 4
  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> in = vec_input[i];

    vec4_t<float> x;
    x.x = static_cast<float>(in.x);
    x.y = static_cast<float>(in.y);
    x.z = static_cast<float>(in.z);
    x.w = static_cast<float>(in.w);
    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
      x.x += static_cast<float>(r.x);
      x.y += static_cast<float>(r.y);
      x.z += static_cast<float>(r.z);
      x.w += static_cast<float>(r.w);
    }

    ss += x.x * x.x;
    ss += x.y * x.y;
    ss += x.z * x.z;
    ss += x.w * x.w;
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  ss = BlockReduce(reduceStore).Reduce(ss, cub::Sum{}, blockDim.x);

  __shared__ float s_rms;
  if (threadIdx.x == 0) {
    s_rms = rsqrtf(ss / hidden_size + epsilon);
  }
  __syncthreads();

  *rms = s_rms;
}

// Vectorized version of vllm::compute_dynamic_per_token_scales
// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void compute_dynamic_per_token_scales(
    float* __restrict__ token_scale, float* __restrict__ all_token_scales,
    scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
    float const rms, float const* __restrict__ scale_ub,
    float const min_scaling_factor, int32_t const hidden_size,
    scalar_t const* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/weight/residual to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_weight =
      reinterpret_cast<vec4_t<scalar_t> const*>(weight);
  vec4_t<scalar_t> const* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual =
        reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
  }

  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};

  int32_t const num_vec_elems = hidden_size >> 2;
  float block_absmax_val_maybe = 0.0f;

#pragma unroll 4
  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> in = vec_input[i];
    vec4_t<scalar_t> const w = vec_weight[i];

    vec4_t<float> x;
    x.x = static_cast<float>(in.x);
    x.y = static_cast<float>(in.y);
    x.z = static_cast<float>(in.z);
    x.w = static_cast<float>(in.w);
    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
      x.x += static_cast<float>(r.x);
      x.y += static_cast<float>(r.y);
      x.z += static_cast<float>(r.z);
      x.w += static_cast<float>(r.w);
    }

    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
    block_absmax_val_maybe = fmaxf(
        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
  }

  using BlockReduce = cub::BlockReduce<float, 1024>;
  __shared__ typename BlockReduce::TempStorage reduceStore;
  block_absmax_val_maybe =
      BlockReduce(reduceStore)
          .Reduce(block_absmax_val_maybe, cub::Max{}, blockDim.x);

  __shared__ float s_token_scale;
  if (threadIdx.x == 0) {
    float scale = 0.0f;
    if (scale_ub) {
      scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      scale = block_absmax_val_maybe;
    }
    // token scale computation
    scale = max(scale / qmax, min_scaling_factor);
    s_token_scale = scale;                 // shared memory store
    all_token_scales[blockIdx.x] = scale;  // global output store
  }
  __syncthreads();

  *token_scale = s_token_scale;
}

// hidden_size must be a multiple of 4
template <typename scalar_t, typename scalar_out_t, bool is_scale_inverted,
          bool has_residual = false>
__device__ void norm_and_quant(scalar_out_t* __restrict__ output,
                               scalar_t const* __restrict__ input,
                               scalar_t const* __restrict__ weight,
                               float const rms, float const scale,
                               int32_t const hidden_size,
                               scalar_t* __restrict__ residual = nullptr) {
  int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);

  // Vectorized input/output/weight/residual to better utilize memory
  // bandwidth.
  vec4_t<scalar_t> const* vec_input =
      reinterpret_cast<vec4_t<scalar_t> const*>(&input[token_offset]);
  vec4_t<scalar_t> const* vec_weight =
      reinterpret_cast<vec4_t<scalar_t> const*>(weight);
  q8x4_t<scalar_out_t>* vec_output =
      reinterpret_cast<q8x4_t<scalar_out_t>*>(&output[token_offset]);
  vec4_t<scalar_t>* vec_residual = nullptr;
  if constexpr (has_residual) {
    vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
  }

  int32_t const num_vec_elems = hidden_size >> 2;

// TODO(luka/varun) extract into type-agnostic vectorized quant function to
// replace scaled_fp8_conversion_vec
#pragma unroll 4
  for (int32_t i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4_t<scalar_t> const in = vec_input[i];
    vec4_t<scalar_t> const w = vec_weight[i];

    vec4_t<float> x;
    x.x = static_cast<float>(in.x);
    x.y = static_cast<float>(in.y);
    x.z = static_cast<float>(in.z);
    x.w = static_cast<float>(in.w);
    if constexpr (has_residual) {
      vec4_t<scalar_t> r = vec_residual[i];
      x.x += static_cast<float>(r.x);
      x.y += static_cast<float>(r.y);
      x.z += static_cast<float>(r.z);
      x.w += static_cast<float>(r.w);
      // Update residual
      r.x = static_cast<scalar_t>(x.x);
      r.y = static_cast<scalar_t>(x.y);
      r.z = static_cast<scalar_t>(x.z);
      r.w = static_cast<scalar_t>(x.w);
      vec_residual[i] = r;
    }

    q8x4_t<scalar_out_t> out;
    out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.x * rms) * w.x, scale);
    out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.y * rms) * w.y, scale);
    out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.z * rms) * w.z, scale);
    out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
        static_cast<scalar_t>(x.w * rms) * w.w, scale);
    vec_output[i] = out;
  }
}

}  // namespace vectorized

}  // namespace vllm
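Restated compactly, per token (a summary of what compute_rms, compute_dynamic_per_token_scales and norm_and_quant above compute; x_i includes the residual when one is given, H is hidden_size, q_max the maximum of the quantized dtype, s_min the min_scaling_factor, and scale_ub applies only when provided):

    r   = rsqrt( (1/H) * sum_i x_i^2 + eps )
    y_i = (x_i * r) * w_i
    s   = max( min( max_i |y_i|, scale_ub ) / q_max, s_min )
    q_i = round_and_saturate( y_i / s )          # int8 path receives 1/s ("inverted" scale)
    q_i = clamp_and_cast( y_i / s )              # fp8 path receives s directly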
csrc/quantization/fused_kernels/quant_conversions.cuh (new file, 81 lines)
@@ -0,0 +1,81 @@
#pragma once

/**
 * __device__ helper functions to deal with float -> quant datatype conversion
 */

#include "quantization/vectorization.cuh"
// TODO(luka/varun):refactor common.cuh to use this file instead
#include "quantization/fp8/common.cuh"

namespace vllm {

// TODO(luka/varun): combine into common utilities for int8
// (with int8_quant_kernels.cu)
static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
#ifdef USE_ROCM
  static const float i8_min =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  static const float i8_max =
      static_cast<float>(std::numeric_limits<int8_t>::max());
  // round
  float dst = std::nearbyint(x);
  // saturate
  dst = std::clamp(dst, i8_min, i8_max);
  return static_cast<int8_t>(dst);
#else
  // CUDA path
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}

static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
  float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
  return static_cast<FP8_TYPE>(r);
}

template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
struct ScaledQuant;

template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
    quant_type_t, is_scale_inverted,
    typename std::enable_if_t<std::is_same_v<quant_type_t, int8_t>>> {
  static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                          float const scale) {
    if constexpr (is_scale_inverted) {
      return float_to_int8_rn(x * scale);
    } else {
      return float_to_int8_rn(x / scale);
    }
  }
};

template <typename quant_type_t, bool is_scale_inverted>
struct ScaledQuant<
    quant_type_t, is_scale_inverted,
    typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
  static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                          float const scale) {
    if constexpr (is_scale_inverted) {
      return float_to_fp8(x * scale);
    } else {
      return float_to_fp8(x / scale);
    }
  }
};

template <typename scalar_t, typename quant_type_t, bool is_scale_inverted>
__device__ void scaled_quant_conversion(quant_type_t* __restrict__ output,
                                        scalar_t const* __restrict__ input,
                                        float const scale, int const tid,
                                        int const num_elements,
                                        int const step) {
  for (int i = tid; i < num_elements; i += step) {
    output[i] = ScaledQuant<quant_type_t, is_scale_inverted>(input[i], scale);
  }
}

}  // namespace vllm
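A rough, non-authoritative Python restatement of the int8 ScaledQuant semantics above (the CUDA path uses cvt.rni.sat, i.e. round-to-nearest-even with saturation; the fp8 path instead clamps to +/-FP8_E4M3_MAX before the cast):

    def scaled_quant_int8_ref(x: float, scale: float, scale_is_inverted: bool) -> int:
        # is_scale_inverted=True means the caller already passed 1/scale.
        y = x * scale if scale_is_inverted else x / scale
        # Round to nearest and saturate to the int8 range.
        return int(max(-128.0, min(127.0, round(y))))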
csrc/quantization/vectorization.cuh (new file, 33 lines)
@@ -0,0 +1,33 @@
#pragma once
/**
 * __device__ datatypes vectorized by 4
 */

// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>

namespace vllm {

// Vectorization containers
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};

}  // namespace vllm
@@ -128,6 +128,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("fused_add_rms_norm_static_fp8_quant", torch::kCUDA,
            &fused_add_rms_norm_static_fp8_quant);

+  // Fused Layernorm + Quant kernels
+  ops.def(
+      "rms_norm_dynamic_per_token_quant(Tensor! result, Tensor input, "
+      "Tensor weight, Tensor! scale, float epsilon, "
+      "Tensor? scale_ub, Tensor!? residual) -> ()");
+  ops.impl("rms_norm_dynamic_per_token_quant", torch::kCUDA,
+           &rms_norm_dynamic_per_token_quant);
+
   // Rotary embedding
   // Apply GPT-NeoX or GPT-J style rotary embedding to query and key.
   ops.def(
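A note on the schema string: in torch library schemas, Tensor! marks arguments the op mutates in place (here result and scale, plus the optional residual as Tensor!?), while Tensor? marks an optional input; this is what the functionalization pass further down relies on when it maps mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} back to these positions.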
@@ -4,10 +4,10 @@ import torch
 import vllm.envs as envs
 from vllm import LLM, SamplingParams
 from vllm.compilation.fix_functionalization import FixFunctionalizationPass
-from vllm.compilation.fusion import (FusionPass, find_auto_fn,
-                                     find_auto_fn_maybe)
+from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
+                                     kFp8DynamicTokenSym, kFp8StaticTensorSym)
+from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.reshapes import RedundantReshapesPass
-from vllm.compilation.vllm_inductor_pass import is_func
 from vllm.config import CompilationConfig

 from .backend import TestBackend
@@ -35,12 +35,16 @@
 ]


-@pytest.mark.parametrize("model",
-                         ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"])
+@pytest.mark.parametrize(
+    "model, quant_key",
+    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e", kFp8StaticTensorSym),
+     ("nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8_DYNAMIC-e2e",
+      kFp8DynamicTokenSym)])
 @pytest.mark.parametrize("do_fusion", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fix_functionalization(model: str, do_fusion: bool):
+def test_fix_functionalization(model: str, quant_key: QuantKey,
+                               do_fusion: bool):
     torch.set_default_device("cuda")

     config = CompilationConfig.PassConfig(enable_fusion=do_fusion,
@@ -78,8 +82,9 @@ def test_fix_functionalization(model: str, do_fusion: bool):

     # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion,
     # and replaced by fused quantized ops in RMS_QUANT_OPS.
-    ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"]
-                          if do_fusion else [RMS_OP])
+    rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
+               ] if do_fusion else [RMS_OP]
+    ops = OPS_IN_MODEL + rms_ops

     for op in ops:
         find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
@@ -3,8 +3,9 @@ import torch
 from compressed_tensors.quantization import FP8_DTYPE

 import vllm.envs as envs
-from vllm.compilation.fusion import (FusionPass, find_auto_fn,
-                                     find_auto_fn_maybe)
+from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
+                                     FusionPass, QuantKey)
+from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
 from vllm.compilation.reshapes import RedundantReshapesPass
 from vllm.config import CompilationConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -16,24 +17,37 @@ from .backend import TestBackend

 class TestModel(torch.nn.Module):

-    def __init__(self, hidden_size: int, eps: float, *args, **kwargs):
+    def __init__(self, hidden_size: int, eps: float, static: bool, *args,
+                 **kwargs):
         super().__init__(*args, **kwargs)
         self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
-        self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(4)]
+        self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        if static:
+            self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
+        else:
+            self.scale = [None for _ in range(2)]
         self.w = [
             torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
             for _ in range(2)
         ]

     def forward(self, x):
-        resid = torch.relu(x)
+        resid = torch.sqrt(x)
         y = self.norm[0](x)

-        x2 = apply_fp8_linear(y, self.w[0], self.scale[0], self.scale[1])
+        x2 = apply_fp8_linear(y,
+                              self.w[0],
+                              self.wscale[0],
+                              self.scale[0],
+                              use_per_token_if_dynamic=True)
         # make sure resid is used for replacement to work
         y2, resid = self.norm[1](x2, resid)

-        x3 = apply_fp8_linear(y2, self.w[1], self.scale[2], self.scale[3])
+        x3 = apply_fp8_linear(y2,
+                              self.w[1],
+                              self.wscale[1],
+                              self.scale[1],
+                              use_per_token_if_dynamic=True)
         y3, resid = self.norm[2](x3, resid)  # use resid here
         return y3

@@ -42,14 +56,13 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
 @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
 @pytest.mark.parametrize("eps", [1e-5, 1e-6])
+@pytest.mark.parametrize("static", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
                     reason="Only test on CUDA")
-def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
+def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static):
     torch.set_default_device("cuda")
-    torch.set_default_dtype(torch.float16)
-    if eps != 1e-5:
-        pytest.skip("Only test eps=1e-5 for now")
+    torch.set_default_dtype(dtype)
+    torch.manual_seed(1)

     # Reshape pass is needed for the fusion pass to work
     config = CompilationConfig.PassConfig(enable_fusion=True,
@@ -58,7 +71,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
     fusion_pass = FusionPass.instance(config)

     backend = TestBackend(reshape_pass, fusion_pass)
-    model = TestModel(hidden_size, eps)
+    model = TestModel(hidden_size, eps, static)

     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size)
@@ -69,16 +82,28 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps):
     model2 = torch.compile(model, backend=backend)
     result2 = model2(x)

-    # Check that it gives the same answer
-    torch.testing.assert_close(result, result2, atol=1e-3, rtol=1e-3)
+    # Higher tol for dynamic, even higher for bfloat16
+    if static:
+        ATOL, RTOL = (1e-3, 1e-3)
+    elif dtype == torch.float16:
+        ATOL, RTOL = (2e-3, 2e-3)
+    else:
+        ATOL, RTOL = (1e-2, 1e-2)
+
+    torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

     # Check substitution worked
     pre_nodes = backend.graph_pre_pass.nodes
     post_nodes = backend.graph_post_pass.nodes

-    rms_quant = torch.ops._C.rms_norm_static_fp8_quant.default
-    add_rms_quant = torch.ops._C.fused_add_rms_norm_static_fp8_quant.default
-    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
+    # static is per-tensor, dynamic is per-token
+    key = QuantKey(dtype=FP8_DTYPE,
+                   static=static,
+                   per_tensor=static,
+                   symmetric=True)
+    rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
+    add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
+    fp8_quant = QUANT_OPS[key]

     # In pre-nodes, fp8 quant should be present and fused kernels should not
     assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
tests/kernels/test_fused_quant_layernorm.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from typing import Optional, Tuple, Union

import pytest
import torch

import vllm._custom_ops as ops
from tests.kernels.utils import opcheck
from vllm.model_executor.layers.layernorm import RMSNorm

DTYPES = [torch.bfloat16, torch.float]
QUANT_DTYPES = [torch.int8, torch.float8_e4m3fn]
VEC_HIDDEN_SIZES = range(1024, 1030)
# Avoid combinatorial explosion with full Cartesian product
NUM_TOKENS_HIDDEN_SIZES = [
    *[(1, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5120, 5137]],
    *[(83, i) for i in [1, 1033, 2048, 5120]],
    *[(2048, i) for i in [1, 64, *VEC_HIDDEN_SIZES, 5137]],
    *[(4096, i) for i in [1, 64, 5137]],
]

ADD_RESIDUAL = [False, True]
SCALE_UBS = [True, False]
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]

EPS = 1e-6

## Helpers


def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:
    return torch.as_tensor(x, dtype=torch.float32, device='cuda')


def ref_rms_norm(rms_norm_layer: RMSNorm,
                 x: torch.Tensor,
                 residual: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    if residual is not None:
        residual = residual.clone()
        out, residual = rms_norm_layer.forward_native(x, residual)
    else:
        out = rms_norm_layer.forward_native(x)

    return out, residual


def ref_dynamic_per_token_quant(rms_norm_layer: RMSNorm,
                                x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                residual: Optional[torch.Tensor],
                                scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    if scale_ub is not None:
        assert quant_dtype == torch.float8_e4m3fn

    # Norm
    torch_out, residual = ref_rms_norm(rms_norm_layer, x, residual)

    # Quant
    if quant_dtype == torch.float8_e4m3fn:
        torch_out, scales = ops.scaled_fp8_quant(torch_out,
                                                 scale_ub=scale_ub,
                                                 use_per_token_if_dynamic=True)
    else:
        assert quant_dtype == torch.int8
        torch_out, scales = ops.scaled_int8_quant(torch_out)

    return torch_out, scales, residual


def ref_impl(rms_norm_layer: RMSNorm,
             x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    return ref_dynamic_per_token_quant(rms_norm_layer, x, quant_dtype,
                                       residual, scale_ub)


def ops_dynamic_per_token_quant(weight: torch.Tensor,
                                x: torch.Tensor,
                                quant_dtype: torch.dtype,
                                residual: Optional[torch.Tensor],
                                scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    if residual is not None:
        residual = residual.clone()
    out, scales = ops.rms_norm_dynamic_per_token_quant(x, weight, EPS,
                                                       quant_dtype, scale_ub,
                                                       residual)
    return out, scales, residual


def ops_impl(weight: torch.Tensor,
             x: torch.Tensor,
             quant_dtype: torch.dtype,
             residual: Optional[torch.Tensor],
             scale_ub: Optional[torch.Tensor]) \
        -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    return ops_dynamic_per_token_quant(weight, x, quant_dtype, residual,
                                       scale_ub)


@pytest.mark.parametrize("num_tokens, hidden_size", NUM_TOKENS_HIDDEN_SIZES)
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("scale_ub", SCALE_UBS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    add_residual: bool,
    scale_ub: bool,
    dtype: torch.dtype,
    quant_dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)

    if scale_ub is not None and quant_dtype != torch.float8_e4m3fn:
        # skip
        return

    layer = RMSNorm(hidden_size, EPS).to(dtype=dtype)

    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)

    # Make inputs
    scale = 1 / (hidden_size)
    x = torch.randn(num_tokens, hidden_size, dtype=dtype) * scale
    residual = torch.randn_like(x) * scale if add_residual else None
    if scale_ub is not None:
        rms_x, _ = ref_rms_norm(layer, x, residual)
        scale_ub = torch.mean(rms_x).to(dtype=torch.float32, device='cuda')

    ref_out, ref_scales, ref_residual = \
        ref_impl(layer, x, quant_dtype, residual, scale_ub)
    ops_out, ops_scales, ops_residual = \
        ops_impl(layer.weight, x, quant_dtype, residual, scale_ub)

    assert ref_out.dtype == quant_dtype
    assert ops_out.dtype == quant_dtype
    assert torch.allclose(ref_scales, ops_scales)
    if quant_dtype == torch.int8:
        # big atol to account for round-off errors.
        assert torch.allclose(ref_out, ops_out, atol=1)
    else:
        assert torch.allclose(ref_out.to(dtype=torch.float32),
                              ops_out.to(dtype=torch.float32))
    if add_residual:
        assert torch.allclose(ref_residual, ops_residual)

    output = torch.empty_like(x, dtype=quant_dtype)
    scales = torch.empty((x.numel() // x.shape[-1], 1),
                         device=x.device,
                         dtype=torch.float32)

    opcheck(torch.ops._C.rms_norm_dynamic_per_token_quant,
            (output, x, layer.weight, scales, 1e-5, scale_ub, residual))
@@ -249,6 +249,26 @@ def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
                                                block_table_bound)
 
 
+# fused quant layer norm ops
+def rms_norm_dynamic_per_token_quant(
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        epsilon: float,
+        quant_dtype: torch.dtype,
+        scale_ub: Optional[torch.Tensor] = None,
+        residual: Optional[torch.Tensor] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    output = torch.empty_like(input, dtype=quant_dtype)
+    scales = torch.empty((input.numel() // input.shape[-1], 1),
+                         device=input.device,
+                         dtype=torch.float32)
+
+    torch.ops._C.rms_norm_dynamic_per_token_quant(output, input, weight,
+                                                  scales, epsilon, scale_ub,
+                                                  residual)
+    return output, scales
+
+
 # quantization ops
 # awq
 def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
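A minimal usage sketch of the new wrapper (assumes a CUDA build of vLLM with this kernel compiled in; shapes are illustrative):

import torch

import vllm._custom_ops as ops

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(1024, dtype=torch.bfloat16, device="cuda")

# One kernel launch: RMSNorm followed by dynamic per-token fp8 quantization.
out, scales = ops.rms_norm_dynamic_per_token_quant(
    x, weight, epsilon=1e-6, quant_dtype=torch.float8_e4m3fn)
# out: (16, 1024) float8_e4m3fn; scales: (16, 1) float32, one scale per token.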
@@ -6,7 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
 
 from vllm.logger import init_logger
 
-from .vllm_inductor_pass import VllmInductorPass, is_func
+from .fx_utils import is_func
+from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
 
@@ -53,14 +54,16 @@ class FixFunctionalizationPass(VllmInductorPass):
                     self.insert_defunctionalized(graph, node)
                     self._remove(node)
 
-            # These 2 replacements avoid the most copies for LLaMa.
+            # rms_norm replacements avoid the most copies for LLaMa.
             elif at_target == torch.ops._C.fused_add_rms_norm.default:
                 mutated_args = {1: 'input', 2: 'residual'}
                 self.defunctionalize(graph, node, mutated_args)
             elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default:  # noqa: E501
                 mutated_args = {1: 'result', 2: 'residual'}
                 self.defunctionalize(graph, node, mutated_args)
+            elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default:  # noqa: E501
+                mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
+                self.defunctionalize(graph, node, mutated_args)
             elif at_target in [
                     torch.ops._C.rms_norm.default,
                     torch.ops._C.rms_norm_static_fp8_quant.default
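For reference, the indices in mutated_args above mirror the op's out-arguments. The sketch below is illustrative only (assumes a CUDA build): it shows the in-place call that defunctionalization restores; in the functionalized graph, at[1], at[2] and at[3] of the auto_functionalized node correspond to result, scale and residual.

import torch

N, D = 4, 64
x = torch.randn(N, D, dtype=torch.bfloat16, device="cuda")
w = torch.ones(D, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)
result = torch.empty(N, D, dtype=torch.float8_e4m3fn, device="cuda")
scale = torch.empty(N, 1, dtype=torch.float32, device="cuda")

# Argument order follows the Python wrapper above:
# (result, input, weight, scale, epsilon, scale_ub, residual)
torch.ops._C.rms_norm_dynamic_per_token_quant(result, x, w, scale, 1e-6,
                                              None, residual)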
@@ -1,28 +1,196 @@
-import operator
-from typing import Iterable, List, Optional
+from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
 
 import torch
+import torch._inductor.pattern_matcher as pm
+# TODO(luka) use vllm.utils once #10836 landed
+from compressed_tensors.quantization import FP8_DTYPE
+from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
-from torch._inductor.pattern_matcher import (Match, PatternMatcherPass,
-                                             fwd_only, register_replacement)
+from torch._inductor.pattern_matcher import PatternMatcherPass
+from torch._ops import OpOverload
 
 from vllm.config import CompilationConfig
 from vllm.logger import init_logger
 
-from .vllm_inductor_pass import VllmInductorPass, is_func
+from .fx_utils import find_getitem_maybe
+from .multi_output_match import MultiOutputMatch
+from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
 
 
-def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
-                       input: torch.Tensor, weight: torch.Tensor,
-                       scale: torch.Tensor):
-    at1 = auto_functionalized(torch.ops._C.rms_norm.default,
-                              result=result_rms,
-                              input=input,
-                              weight=weight,
-                              epsilon=1e-5)
-    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
-                              result=result,
-                              input=at1[1],
-                              scale=scale)
+def empty_bf16(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
+
+
+def empty_fp32(*args, **kwargs):
+    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
+
+
+RMS_OP = torch.ops._C.rms_norm.default
+RMS_ADD_OP = torch.ops._C.fused_add_rms_norm.default
+
+
+class QuantKey(NamedTuple):
+    """
+    Named tuple for identifying the type of quantization.
+    dtype: quantized data type
+    static: static quantization if True, dynamic if False
+    per_tensor: per-tensor quantization if True, per-token if False
+    symmetric: symmetric if True, asymmetric if False
+    """
+    dtype: torch.dtype
+    static: bool
+    per_tensor: bool = True
+    symmetric: bool = True
+
+    def __str__(self):
+        return (f"QuantKey({'static' if self.static else 'dynamic'},"
+                f"{fx.graph.dtype_abbrs[self.dtype]},"
+                f"{'per_tensor' if self.per_tensor else 'per_token'},"
+                f"{'a' if not self.symmetric else ''}symmetric)")
+
+
+kFp8StaticTensorSym = QuantKey(FP8_DTYPE, True, True, True)
+kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, False, True, True)
+kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, False, False, True)
+
+QUANT_OPS: Dict[QuantKey, OpOverload] = {
+    kFp8StaticTensorSym: torch.ops._C.static_scaled_fp8_quant.default,  # noqa
+    kFp8DynamicTensorSym:
+    torch.ops._C.dynamic_scaled_fp8_quant.default,  # noqa
+    kFp8DynamicTokenSym:
+    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default,  # noqa
+}
+
+
+class FusedRMSQuantKey(NamedTuple):
+    """
+    Named tuple for identifying the type of RMSNorm + quant fusion.
+    quant: type of quantization
+    fused_add: does the op also perform the residual add
+    """
+    quant: QuantKey
+    fused_add: bool
+
+    def __str__(self):
+        return (f"FusedQuantKey({self.quant}, with"
+                f"{'' if self.fused_add else 'out'} residual)")
+
+
+FUSED_OPS: Dict[FusedRMSQuantKey, OpOverload] = {
+    FusedRMSQuantKey(kFp8StaticTensorSym, False):
+    torch.ops._C.rms_norm_static_fp8_quant.default,  # noqa
+    FusedRMSQuantKey(kFp8StaticTensorSym, True):
+    torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,  # noqa
+    FusedRMSQuantKey(kFp8DynamicTokenSym, False):
+    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
+    FusedRMSQuantKey(kFp8DynamicTokenSym, True):
+    torch.ops._C.rms_norm_dynamic_per_token_quant.default,  # noqa
+}
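An illustrative lookup (assuming the definitions above are in scope): the two tables resolve a quantization scheme to its unfused quant op and to the fused RMSNorm+quant op. For the dynamic per-token fp8 case added by this commit:

key = FusedRMSQuantKey(quant=kFp8DynamicTokenSym, fused_add=False)
assert QUANT_OPS[key.quant] == \
    torch.ops._C.dynamic_per_token_scaled_fp8_quant.default
assert FUSED_OPS[key] == \
    torch.ops._C.rms_norm_dynamic_per_token_quant.default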
+
+
+class QuantMultiOutputMatch(MultiOutputMatch):
+
+    def __init__(self, match: pm.Match, quant_op, fused_op):
+        super().__init__(match)
+        assert isinstance(quant_op, OpOverload)
+        assert isinstance(fused_op, OpOverload)
+        self.QUANT_OP = quant_op  # in-place quant op
+        self.FUSED_OP = fused_op  # in-place fused quant op
+
+    def insert_fused_node(self, fused_return_mapping: Dict[int, Tuple[fx.Node,
+                                                                      int]],
+                          **kwargs):
+        """
+        This utility function inserts an auto-functionalized node for FUSED_OP.
+        It also correctly sets its meta value and rebinds the users of the
+        unfused nodes to use the fused node instead.
+
+        :param fused_return_mapping: A dictionary, mapping from getitem indices
+        of the fused node result to a tuple of the old node and a getitem index.
+        :param kwargs: kwargs that get directly forwarded to the auto_fn node
+
+        Example:
+        If we want to replace this graph:
+        _, x1, x2 = auto_fn(op1)
+        _, y1, y2 = auto_fn(op2)
+
+        with
+        _, x1, y2, x2 = auto_fn(FUSED_OP)
+
+        we would call:
+        insert_fused_node({1: (op1_node, 1), 2: (op2_node, 2), 3: (op1_node, 2)})
+
+        Note that the 0th element is None for auto-functionalized in-place ops.
+        Hence, others appear 1-indexed.
+        """
+        fused_node = self.insert_auto_fn(self.FUSED_OP, kwargs)
+        indices = fused_return_mapping.keys()
+        getitem_nodes = self.insert_getitems(fused_node, indices)
+
+        # Prepare the meta value, use a list so it's mutable
+        meta_val = [None] * (max(indices) + 1)
+
+        # Iterate through elements of the tuple produced by fused_node
+        for idx, getitem_node in zip(indices, getitem_nodes):
+            old_node, old_idx = fused_return_mapping[idx]
+
+            # If the old value was never used, the old_getitem might not exist
+            old_getitem = find_getitem_maybe(old_node, old_idx)
+            if old_getitem is not None:
+                # Rebind the users of match getitem nodes to use the new nodes.
+                # The old nodes will be removed by DCE at the end of the pass.
+                old_getitem.replace_all_uses_with(getitem_node)
+                getitem_node.meta["val"] = old_getitem.meta["val"]
+
+            # Extract the appropriate meta value
+            # It is present even if the getitem node does not exist
+            meta_val[idx] = old_node.meta["val"][old_idx]
+
+        # Fix the meta value on the new fused node
+        fused_node.meta["val"] = tuple(meta_val)
+
+
+class RMSNormQuantPattern:
+
+    def __init__(self, epsilon: float, key: FusedRMSQuantKey):
+        self.epsilon = epsilon
+        self.quant_dtype = key.quant.dtype
+
+        assert key.quant in QUANT_OPS, \
+            f"unsupported quantization scheme {key.quant}"
+        self.QUANT_OP = QUANT_OPS[key.quant]
+
+        assert key in FUSED_OPS, \
+            f"unsupported fused rmsnorm+quant op for {key}"
+        self.FUSED_OP = FUSED_OPS[key]
+
+
+class RMSNormStaticQuantPattern(RMSNormQuantPattern):
+
+    def __init__(self,
+                 epsilon: float,
+                 quant_dtype: torch.dtype,
+                 symmetric=True):
+        fused_key = FusedRMSQuantKey(fused_add=False,
+                                     quant=QuantKey(dtype=quant_dtype,
+                                                    static=True,
+                                                    per_tensor=True,
+                                                    symmetric=symmetric))
+        super().__init__(epsilon, fused_key)
+
+    def register(self, pm_pass: PatternMatcherPass):
+        # Cannot use methods, as the self argument affects tracing
+        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
+                    input: torch.Tensor, weight: torch.Tensor,
+                    scale: torch.Tensor):
+            at1 = auto_functionalized(RMS_OP,
+                                      result=result_rms,
+                                      input=input,
+                                      weight=weight,
+                                      epsilon=self.epsilon)
+            at2 = auto_functionalized(self.QUANT_OP,
+                                      result=result,
+                                      input=at1[1],
+                                      scale=scale)
@@ -30,30 +198,56 @@ def rms_pattern_static(result: torch.Tensor, result_rms: torch.Tensor,
-    # result
-    return at2[1]
-
-
-def rms_replacement_static(result: torch.Tensor, result_rms: torch.Tensor,
-                           input: torch.Tensor, weight: torch.Tensor,
-                           scale: torch.Tensor):
-    at = auto_functionalized(torch.ops._C.rms_norm_static_fp8_quant.default,
-                             result=result,
-                             input=input,
-                             weight=weight,
-                             scale=scale,
-                             epsilon=1e-5)
-
-    # result
-    return at[1]
-
-
-def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor,
-                                residual: torch.Tensor, weight: torch.Tensor,
-                                scale: torch.Tensor):
-    at = auto_functionalized(torch.ops._C.fused_add_rms_norm.default,
-                             input=input,
-                             residual=residual,
-                             weight=weight,
-                             epsilon=1e-5)
-    at1 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
-                              result=result,
-                              input=at[1],
-                              scale=scale)
+            # result
+            return at2[1]
+
+        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
+                        input: torch.Tensor, weight: torch.Tensor,
+                        scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     weight=weight,
+                                     scale=scale,
+                                     epsilon=self.epsilon)
+
+            # result
+            return at[1]
+
+        inputs = [
+            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
+            empty_bf16(5, 4),  # result_rms
+            empty_bf16(5, 4),  # input
+            empty_bf16(1, 5),  # weight
+            empty_fp32(1, 1)  # scale
+        ]
+
+        pm.register_replacement(pattern, replacement, inputs, pm.fwd_only,
+                                pm_pass)
+
+
+class FusedAddRMSNormStaticQuantPattern(RMSNormQuantPattern):
+
+    def __init__(self,
+                 epsilon: float,
+                 quant_dtype: torch.dtype,
+                 symmetric=True):
+        key = FusedRMSQuantKey(fused_add=True,
+                               quant=QuantKey(dtype=quant_dtype,
+                                              static=True,
+                                              per_tensor=True,
+                                              symmetric=symmetric))
+        super().__init__(epsilon, key)
+
+    def register(self, pm_pass: PatternMatcherPass,
+                 record_match: Callable[[MultiOutputMatch], bool]):
+
+        def pattern(result: torch.Tensor, input: torch.Tensor,
+                    residual: torch.Tensor, weight: torch.Tensor,
+                    scale: torch.Tensor):
+            at = auto_functionalized(RMS_ADD_OP,
+                                     input=input,
+                                     residual=residual,
+                                     weight=weight,
+                                     epsilon=self.epsilon)
+            at1 = auto_functionalized(self.QUANT_OP,
+                                      result=result,
+                                      input=at[1],
+                                      scale=scale)
@@ -61,69 +255,263 @@ def rms_pattern_residual_static(result: torch.Tensor, input: torch.Tensor,
-    # result, residual
-    return at1[1], at[2]
-
-
-def rms_replacement_residual_static(result: torch.Tensor, input: torch.Tensor,
-                                    residual: torch.Tensor,
-                                    weight: torch.Tensor, scale: torch.Tensor):
-    at = auto_functionalized(
-        torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
-        result=result,
-        input=input,
-        residual=residual,
-        weight=weight,
-        scale=scale,
-        epsilon=1e-5)
-
-    # result, residual
-    return at[1], at[2]
-
-
-def empty_bf16(*args, **kwargs):
-    return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda")
-
-
-def empty_fp8(*args, **kwargs):
-    fp8 = torch.float8_e4m3fn
-    return torch.empty(*args, **kwargs, dtype=fp8, device="cuda")
-
-
-def empty_fp32(*args, **kwargs):
-    return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda")
-
-
-# Utilities for post-processing multi-output matches
-
-
-# Returns the first auto_functionalized node with the given op (if it exists)
-def find_auto_fn_maybe(nodes: Iterable[torch.fx.Node],
-                       op) -> Optional[torch.fx.Node]:
-    for node in nodes:
-        if is_func(node, auto_functionalized) and node.args[0] == op:  # noqa
-            return node
-    return None
-
-
-# Returns the first auto_functionalized node with the given op
-def find_auto_fn(nodes: Iterable[torch.fx.Node], op) -> torch.fx.Node:
-    node = find_auto_fn_maybe(nodes, op)
-    assert node is not None, f"Could not find {op} in nodes {nodes}"
-    return node
-
-
-# Returns the getitem node that extracts the idx-th element from node
-# (if it exists)
-def find_getitem_maybe(node: torch.fx.Node,
-                       idx: int) -> Optional[torch.fx.Node]:
-    for user in node.users:
-        if is_func(user, operator.getitem) and user.args[1] == idx:
-            return user
-    return None
-
-
-# Returns the getitem node that extracts the idx-th element from node
-def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node:
-    ret = find_getitem_maybe(node, idx)
-    assert ret is not None, f"Could not find getitem {idx} in node {node}"
-    return ret
+            # result, residual
+            return at1[1], at[2]
+
+        def replacement(result: torch.Tensor, input: torch.Tensor,
+                        residual: torch.Tensor, weight: torch.Tensor,
+                        scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     residual=residual,
+                                     weight=weight,
+                                     scale=scale,
+                                     epsilon=self.epsilon)
+
+            # result, residual
+            return at[1], at[2]
+
+        inputs = [
+            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
+            empty_bf16(5, 4),  # input
+            empty_bf16(5, 4),  # residual
+            empty_bf16(1, 5),  # weight
+            empty_fp32(1, 1)  # scale
+        ]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+            extra_check=lambda m: record_match(
+                self.Match(m, self.QUANT_OP, self.FUSED_OP)))
+
+    class Match(QuantMultiOutputMatch):
+
+        def process(self):
+            # Find the nodes in the match that we need to rebind
+            rms_node = self.find_auto_fn(RMS_ADD_OP)
+            quant_node = self.find_auto_fn(self.QUANT_OP)
+
+            assert len(rms_node.users) == 2
+            assert len(quant_node.users) == 1
+
+            # First, insert a new auto_functionalized node for the fused op,
+            # as well as getitem nodes to extract the result and residual.
+            # The auto_fn node returns a tuple of (None, result, residual).
+            #
+            # The resulting graph looks like this:
+            # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...)  # noqa
+            # result_node_new = at[1]
+            # residual_node_new = at[2]
+            with self.inserting_after_match():
+                # Missing epsilon, scalars cannot be inputs to the pattern
+                kwargs = self.match.kwargs.copy()
+
+                # 0 is always None
+                fused_return_mapping = {1: (quant_node, 1), 2: (rms_node, 2)}
+                self.insert_fused_node(fused_return_mapping,
+                                       epsilon=rms_node.kwargs["epsilon"],
+                                       **kwargs)
+
+
+class RMSNormDynamicQuantPattern(RMSNormQuantPattern):
+
+    def __init__(self,
+                 epsilon: float,
+                 quant_dtype: torch.dtype,
+                 per_tensor: bool,
+                 symmetric=True):
+        key = FusedRMSQuantKey(fused_add=False,
+                               quant=QuantKey(dtype=quant_dtype,
+                                              static=False,
+                                              per_tensor=per_tensor,
+                                              symmetric=symmetric))
+        super().__init__(epsilon, key)
+
+    def register(self, pm_pass: PatternMatcherPass,
+                 record_match: Callable[[MultiOutputMatch], bool]):
+
+        def pattern(result: torch.Tensor, result_rms: torch.Tensor,
+                    input: torch.Tensor, weight: torch.Tensor,
+                    scale: torch.Tensor):
+            at1 = auto_functionalized(RMS_OP,
+                                      result=result_rms,
+                                      input=input,
+                                      weight=weight,
+                                      epsilon=self.epsilon)
+            at2 = auto_functionalized(self.QUANT_OP,
+                                      result=result,
+                                      input=at1[1],
+                                      scale=scale,
+                                      scale_ub=None)
+
+            # result, scale
+            return at2[1], at2[2]
+
+        def replacement(result: torch.Tensor, result_rms: torch.Tensor,
+                        input: torch.Tensor, weight: torch.Tensor,
+                        scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     weight=weight,
+                                     scale=scale,
+                                     epsilon=self.epsilon,
+                                     scale_ub=None,
+                                     residual=None)
+
+            # result, scale
+            return at[1], at[2]
+
+        inputs = [
+            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
+            empty_bf16(5, 4),  # result_rms
+            empty_bf16(5, 4),  # input
+            empty_bf16(1, 5),  # weight
+            empty_fp32(1, 1)  # scale
+        ]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+            extra_check=lambda m: record_match(
+                self.Match(m, self.QUANT_OP, self.FUSED_OP)))
+
+    class Match(QuantMultiOutputMatch):
+
+        def process(self):
+            # Find the nodes in the match that we need to rebind
+            rms_node = self.find_auto_fn(RMS_OP)
+            quant_node = self.find_auto_fn(self.QUANT_OP)
+
+            assert len(rms_node.users) == 1
+            assert len(quant_node.users) == 2
+
+            # First, insert a new auto_functionalized node for the fused op,
+            # as well as getitem nodes to extract the result and scale.
+            # The auto_fn node returns a tuple of (None, result, scale).
+            #
+            # The resulting graph looks like this:
+            # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...)  # noqa
+            # result_node_new = at[1]
+            # scale_node_new = at[2]
+            with self.inserting_after_match():
+                # Missing epsilon, scalars cannot be inputs to the pattern
+                kwargs = self.match.kwargs.copy()
+                del kwargs["result_rms"]  # not used in the fused op
+
+                fused_return_mapping = {1: (quant_node, 1), 2: (quant_node, 2)}
+                self.insert_fused_node(
+                    fused_return_mapping,
+                    epsilon=rms_node.kwargs["epsilon"],
+                    scale_ub=None,  # not used but required
+                    residual=None,  # not used but required
+                    **kwargs)
+
+
+class FusedAddRMSNormDynamicQuantPattern(RMSNormQuantPattern):
+
+    def __init__(self,
+                 epsilon: float,
+                 quant_dtype: torch.dtype,
+                 per_tensor: bool = True,
+                 symmetric=True):
+        key = FusedRMSQuantKey(fused_add=True,
+                               quant=QuantKey(dtype=quant_dtype,
+                                              static=False,
+                                              per_tensor=per_tensor,
+                                              symmetric=symmetric))
+        super().__init__(epsilon, key)
+
+    def register(self, pm_pass: PatternMatcherPass,
+                 record_match: Callable[[MultiOutputMatch], bool]):
+
+        def pattern(result: torch.Tensor, input: torch.Tensor,
+                    residual: torch.Tensor, weight: torch.Tensor,
+                    scale: torch.Tensor):
+            at = auto_functionalized(RMS_ADD_OP,
+                                     input=input,
+                                     residual=residual,
+                                     weight=weight,
+                                     epsilon=self.epsilon)
+            at1 = auto_functionalized(self.QUANT_OP,
+                                      result=result,
+                                      input=at[1],
+                                      scale=scale,
+                                      scale_ub=None)
+
+            # result, residual, scale
+            return at1[1], at[2], at1[2]
+
+        def replacement(result: torch.Tensor, input: torch.Tensor,
+                        residual: torch.Tensor, weight: torch.Tensor,
+                        scale: torch.Tensor):
+            at = auto_functionalized(self.FUSED_OP,
+                                     result=result,
+                                     input=input,
+                                     weight=weight,
+                                     scale=scale,
+                                     epsilon=self.epsilon,
+                                     scale_ub=None,
+                                     residual=residual)
+
+            # result, residual, scale
+            return at[1], at[3], at[2]
+
+        inputs = [
+            torch.empty(5, 4, device="cuda", dtype=self.quant_dtype),  # result
+            empty_bf16(5, 4),  # input
+            empty_bf16(5, 4),  # residual
+            empty_bf16(1, 5),  # weight
+            empty_fp32(1, 1)  # scale
+        ]
+
+        pm.register_replacement(
+            pattern,
+            replacement,
+            inputs,
+            pm.fwd_only,
+            pm_pass,
+            extra_check=lambda m: record_match(
+                self.Match(m, self.QUANT_OP, self.FUSED_OP)))
+
+    class Match(QuantMultiOutputMatch):
+
+        def process(self):
+            # Find the nodes in the match that we need to rebind
+            rms_node = self.find_auto_fn(RMS_ADD_OP)
+            quant_node = self.find_auto_fn(self.QUANT_OP)
+
+            assert len(rms_node.users) == 2
+            assert len(quant_node.users) == 2
+
+            # First, insert a new auto_functionalized node for the fused op,
+            # as well as getitem nodes to extract result, scale, and residual.
+            # The auto_fn node returns a tuple (None, result, scale, residual).
+            #
+            # The resulting graph looks like this:
+            # at = auto_functionalized(torch.ops._C.rms_norm_dynamic_per_token_quant.default, ...)  # noqa
+            # result_node_new = at[1]
+            # scale_node_new = at[2]
+            # residual_node_new = at[3]
+            with self.inserting_after_match():
+                # Missing epsilon, scalars cannot be inputs to the pattern
+                kwargs = self.match.kwargs.copy()
+
+                fused_return_mapping = {
+                    1: (quant_node, 1),  # result
+                    2: (quant_node, 2),  # scale
+                    3: (rms_node, 2),  # residual
+                }
+                self.insert_fused_node(
+                    fused_return_mapping,
+                    epsilon=rms_node.kwargs["epsilon"],
+                    scale_ub=None,  # not used but required
+                    **kwargs)
 
 
 class FusionPass(VllmInductorPass):
@@ -158,41 +546,39 @@ class FusionPass(VllmInductorPass):
             "FusionPass singleton instance already exists"
         super().__init__(config)
 
-        self.matches: List[Match] = []
+        self.matches: List[MultiOutputMatch] = []
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="fusion_pass")
 
-        # Fuse rms_norm + static_scaled_fp8_quant into
-        # rms_norm_static_fp8_quant
-        inputs = [
-            empty_fp8(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(1, 5),
-            empty_fp32(1, 1)
-        ]
-        register_replacement(rms_pattern_static, rms_replacement_static,
-                             inputs, fwd_only, self.patterns)
-
-        # Fuse fused_add_rms_norm + static_scaled_fp8_quant into
-        # fused_add_rms_norm_static_fp8_quant
-        # Because pattern has 2 outputs, we need to manually process the match
-        # (see process_matches)
-        inputs = [
-            empty_fp8(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(5, 4),
-            empty_bf16(1, 5),
-            empty_fp32(1, 1)
-        ]
-        register_replacement(rms_pattern_residual_static,
-                             rms_replacement_residual_static,
-                             inputs,
-                             fwd_only,
-                             self.patterns,
-                             extra_check=lambda m: self.record_match(m))
-
-    def record_match(self, match: Match) -> bool:
+        for epsilon in [1e-5, 1e-6]:
+            # Fuse rms_norm + static fp8 quant
+            RMSNormStaticQuantPattern(epsilon,
+                                      FP8_DTYPE).register(self.patterns)
+
+            # Matches for patterns below have 2 or more outputs,
+            # so we need to process them manually (see process_matches)
+
+            # Fuse rms_norm + static fp8 quant
+            FusedAddRMSNormStaticQuantPattern(epsilon, FP8_DTYPE).register(
+                self.patterns, self.record_match)
+
+            # Fuse rms_norm + dynamic per-token fp8 quant
+            RMSNormDynamicQuantPattern(epsilon, FP8_DTYPE,
+                                       per_tensor=False).register(
+                                           self.patterns, self.record_match)
+
+            # Fuse fused_add_rms_norm + dynamic per-token fp8 quant
+            FusedAddRMSNormDynamicQuantPattern(epsilon,
+                                               FP8_DTYPE,
+                                               per_tensor=False).register(
+                                                   self.patterns,
+                                                   self.record_match)
+
+        # WARNING: This is a hack to clear the pattern matcher cache
+        # and allow multiple values of epsilon.
+        torch._inductor.pattern_matcher._seen_patterns.clear()
+
+    def record_match(self, match: MultiOutputMatch) -> bool:
         # Hijack the extra_check to record the match and
         # save it for post-processing.
         self.matches.append(match)
@@ -200,83 +586,20 @@ class FusionPass(VllmInductorPass):
         # Return False to prevent automatic replacement.
         return False
 
-    def process_matches(self, graph: torch.fx.Graph):
+    def process_matches(self, graph: fx.Graph):
         """
         Manually process multi-output matches and replace them with fused nodes.
-        This is necessary because the automatic replacement for multi-output
-        matches is broken: https://github.com/pytorch/pytorch/issues/137280
+        See MultiOutputMatch for more details.
         """
         for match in self.matches:
-            # To avoid use-before-definition errors, insert replacement nodes
-            # after the last node in the match.
-            # match.nodes is not guaranteed to be sorted.
-            # Find the last node in the match.
-            for last_node_in_match in reversed(graph.nodes):
-                if last_node_in_match in match.nodes:
-                    break
-            else:
-                raise ValueError("No nodes in graph")
-
-            # Insert a new auto_functionalized node for the fused operation,
-            # as well as getitem nodes to extract the result and residual.
-            # The auto_functionalized node returns a tuple of
-            # (None, result, residual) - None is the function return value.
-            # The resulting graph looks like this:
-            # at = auto_functionalized(torch.ops._C.fused_add_rms_norm_static_fp8_quant.default, ...)  # noqa
-            # result_node_new = at[1]
-            # residual_node_new = at[2]
-            with graph.inserting_after(last_node_in_match):
-                kwargs = match.kwargs
-                kwargs["epsilon"] = 1e-5  # Currently hard-coded in RMSNorm
-
-                fused_node = graph.call_function(
-                    auto_functionalized,
-                    (torch.ops._C.fused_add_rms_norm_static_fp8_quant.default,
-                     ),
-                    kwargs=kwargs)
-
-                graph.inserting_after(fused_node)
-                result_node_new = graph.call_function(operator.getitem,
-                                                      (fused_node, 1))
-                residual_node_new = graph.call_function(
-                    operator.getitem, (fused_node, 2))
-
-            # Last part of replacement is rebinding the users of nodes in the
-            # match to use the new nodes.
-
-            # Find the nodes in the match that we need to rebind
-            rms_node = find_auto_fn(match.nodes,
-                                    torch.ops._C.fused_add_rms_norm.default)
-            quant_node = find_auto_fn(
-                match.nodes, torch.ops._C.static_scaled_fp8_quant.default)
-
-            assert len(rms_node.users) == 2
-            assert len(quant_node.users) == 1
-
-            # meta["val"] is used by de-functionalization and has to contain the
-            # value of the node (tuple of tensors) that would be returned by the
-            # functionalized node during tracing.
-
-            rms_tup = rms_node.meta["val"]
-            quant_tup = quant_node.meta["val"]
-
-            # The result of fused_node must be a tuple with the first element
-            # None (the function return value) and the remaining elements
-            # representing the mutated inputs.
-            fused_tup = (None, quant_tup[1], rms_tup[1], rms_tup[2])
-            fused_node.meta["val"] = fused_tup
-
-            # Find the getitem nodes and replace their uses with the new nodes.
-            # The old nodes will be removed by DCE at the end of the pass.
-            find_getitem(rms_node, 2).replace_all_uses_with(residual_node_new)
-            find_getitem(quant_node, 1).replace_all_uses_with(result_node_new)
+            match.process()
 
         # Finally, remove matched nodes
         graph.eliminate_dead_code()
         assert all(node not in graph.nodes for match in self.matches
-                   for node in match.nodes)
+                   for node in match.match.nodes)
 
-    def __call__(self, graph: torch.fx.Graph):
+    def __call__(self, graph: fx.Graph):
         self.begin()
         self.dump_graph(graph, "before_fusion")
 
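The pass above builds on Inductor's generic pattern matcher. The sketch below is a generic, self-contained illustration of that mechanism (simple elementwise ops unrelated to the fused kernels; the register_replacement / fwd_only / PatternMatcherPass usage mirrors the calls in the diff, and applying the pass to an fx graph is assumed to follow the standard Inductor API):

import torch
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import PatternMatcherPass

patterns = PatternMatcherPass(pass_name="demo_pass")


def pattern(x: torch.Tensor, y: torch.Tensor):
    return (x + y) * 2


def replacement(x: torch.Tensor, y: torch.Tensor):
    return x * 2 + y * 2


# Trace both functions once with example inputs and register the rewrite;
# the pass can then rewrite matching subgraphs of a captured fx graph.
pm.register_replacement(pattern, replacement,
                        [torch.empty(4), torch.empty(4)], pm.fwd_only,
                        patterns)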
vllm/compilation/fx_utils.py (new file, 42 lines)
@@ -0,0 +1,42 @@
import operator
from typing import Iterable, Optional

from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._ops import OpOverload


def is_func(node: fx.Node, target) -> bool:
    return node.op == "call_function" and node.target == target


# Returns the first auto_functionalized node with the given op (if it exists)
def find_auto_fn_maybe(nodes: Iterable[fx.Node],
                       op: OpOverload) -> Optional[fx.Node]:
    for node in nodes:
        if is_func(node, auto_functionalized) and node.args[0] == op:  # noqa
            return node
    return None


# Returns the first auto_functionalized node with the given op
def find_auto_fn(nodes: Iterable[fx.Node], op: OpOverload) -> fx.Node:
    node = find_auto_fn_maybe(nodes, op)
    assert node is not None, f"Could not find {op} in nodes {nodes}"
    return node


# Returns the getitem node that extracts the idx-th element from node
# (if it exists)
def find_getitem_maybe(node: fx.Node, idx: int) -> Optional[fx.Node]:
    for user in node.users:
        if is_func(user, operator.getitem) and user.args[1] == idx:
            return user
    return None


# Returns the getitem node that extracts the idx-th element from node
def find_getitem(node: fx.Node, idx: int) -> fx.Node:
    ret = find_getitem_maybe(node, idx)
    assert ret is not None, f"Could not find getitem {idx} in node {node}"
    return ret
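A small usage sketch for these helpers on a hand-built fx graph (toy nodes for illustration; real callers pass auto_functionalized nodes from a compiled graph):

import operator

import torch
from torch import fx

from vllm.compilation.fx_utils import find_getitem_maybe, is_func

g = fx.Graph()
x = g.placeholder("x")
split = g.call_function(torch.split, (x, 1))
first = g.call_function(operator.getitem, (split, 0))
g.output(first)

assert is_func(first, operator.getitem)
assert find_getitem_maybe(split, 0) is first
assert find_getitem_maybe(split, 1) is None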
vllm/compilation/multi_output_match.py (new file, 105 lines)
@@ -0,0 +1,105 @@
import abc
import operator
from abc import abstractmethod
from typing import Iterable, List, Tuple

from torch import fx
from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor import pattern_matcher as pm
from torch._ops import OpOverload

from vllm.compilation.fx_utils import find_auto_fn


class MultiOutputMatch(abc.ABC):
    """
    This class provides utilities to process multi-output matches and
    manually insert replacements.

    This is necessary because the automatic replacement for multi-output
    matches is broken: https://github.com/pytorch/pytorch/issues/137280
    """

    def __init__(self, match: pm.Match):
        self.match = match

    @abstractmethod
    def process(self):
        """
        Process a multi-output match and manually insert the replacement.

        This method should:
        1. Insert the replacement nodes after the last node in the match.
        2. Rebind the users of nodes in the match to use the new nodes.
        3. Set meta["val"] for de-functionalization.

        The result of an auto-functionalized node is a tuple of tensors.
        The first element is the return value of the function, usually None.
        The remaining elements are the mutated args of the function.

        All auto-functionalized nodes must contain a proper meta["val"],
        as it is used by de-functionalization. meta["val"] has to contain the
        value of the node (tuple of tensors) that would be returned by the
        functionalized node during tracing.

        Existing nodes in the graph all have this property set, but we have
        to set it manually for new nodes we insert.

        Example:
        # op schema: foo(a: Tensor!, b: Tensor, c: Tensor!) -> None
        at = auto_functionalized(torch.ops._C.foo.default, a, b, c)
        # at.meta["val"] = (None, a, c)
        """
        raise NotImplementedError

    @property
    def nodes(self) -> List[fx.Node]:
        return self.match.nodes

    @property
    def graph(self) -> fx.Graph:
        return self.match.graph

    def find_auto_fn(self, op) -> fx.Node:
        """
        Find the first auto_functionalized node with the given op in the match.
        """
        return find_auto_fn(self.nodes, op)

    def inserting_after_match(self):
        """
        Insert nodes after the last node in the match.
        This is done to avoid use-before-definition errors after inserting
        replacement nodes.
        """

        # match.nodes is not guaranteed to be sorted.
        # Find the last node in the match.
        for last_node_in_match in reversed(self.graph.nodes):
            if last_node_in_match in self.match.nodes:
                break
        else:
            raise ValueError("No nodes in graph")

        return self.graph.inserting_after(last_node_in_match)

    def insert_getitems(self, tuple_node: fx.Node,
                        indices: Iterable[int]) -> Tuple[fx.Node, ...]:
        """
        Insert operator.getitem nodes to extract elements from a tuple node.

        :param tuple_node: The tuple node to extract elements from.
        :param indices: The indices of the elements to extract.
        :return: Tuple of the new getitem nodes, corresponding to the indices.
        """
        with self.graph.inserting_after(tuple_node):
            return tuple(
                self.graph.call_function(operator.getitem, (tuple_node, idx))
                for idx in indices)

    def insert_auto_fn(self, op: OpOverload, kwargs):
        """
        Insert an auto_functionalized node with the given op and kwargs.
        """
        return self.graph.call_function(auto_functionalized, (op, ),
                                        kwargs=kwargs)
@@ -5,7 +5,8 @@ from torch import SymInt
 
 from vllm.logger import init_logger
 
-from .vllm_inductor_pass import VllmInductorPass, is_func
+from .fx_utils import is_func
+from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
 
@@ -16,10 +16,6 @@ from .inductor_pass import InductorPass
 logger = init_logger(__name__)
 
 
-def is_func(node: torch.fx.Node, target) -> bool:
-    return node.op == "call_function" and node.target == target
-
-
 class VllmInductorPass(InductorPass):
     """
     An inductor pass with access to vLLM PassConfig.