[Feature][ROCm] Enable fusion pass for torch.compile on ROCm (#15050)

Signed-off-by: charlifu <charlifu@amd.com>

Commit: e85829450d
Parent: effc5d24fa
@@ -30,9 +30,6 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
     fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor =
-      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
-
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
 
@@ -67,8 +64,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
-                      min_scaling_factor);
+    token_scale = max(token_scale / quant_type_max_v<fp8_type>,
+                      min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();
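For reference, the per-token scale in the hunk above follows scale = max(absmax / qmax, 1 / (qmax * 512)), with qmax now taken from quant_type_max_v<fp8_type> and the floor from min_scaling_factor<fp8_type>::val(). Below is a minimal host-side C++ sketch of that arithmetic, not part of this commit; the 448.0 and 224.0 maxima are assumed illustrative values for e4m3fn and the adjusted e4m3fnuz type.

// Host-side sketch of the per-token scale formula (illustration only).
#include <algorithm>
#include <cstdio>

static float token_scale(float absmax, float qmax) {
  float const min_scaling_factor = 1.0f / (qmax * 512.0f);  // lower bound
  return std::max(absmax / qmax, min_scaling_factor);
}

int main() {
  float const absmax = 10.0f;  // per-token max(|x|) from the block reduction
  std::printf("e4m3fn   (qmax assumed 448): %g\n", token_scale(absmax, 448.0f));
  std::printf("e4m3fnuz (qmax assumed 224): %g\n", token_scale(absmax, 224.0f));
  std::printf("tiny absmax clamps to floor: %g\n", token_scale(1e-6f, 448.0f));
}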
@@ -1,20 +1,12 @@
 #pragma once
 
 #include "quantization/vectorization.cuh"
+#include "quantization/utils.cuh"
 
 #include <cmath>
-#include <c10/core/ScalarType.h>
 
-#ifndef USE_ROCM
-  #include <c10/util/Float8_e4m3fn.h>
-  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
-#else
-  #include <ATen/hip/HIPContext.h>
-  #include <c10/util/Float8_e4m3fn.h>
-  #include <c10/util/Float8_e4m3fnuz.h>
+#ifdef USE_ROCM
   #include "amd/quant_utils.cuh"
-  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
-  #define MAYBE_HOST_DEVICE
 #endif
 
 // Determines the preferred FP8 type for the current platform.
@@ -31,29 +23,6 @@ static bool is_fp8_ocp() {
 #endif
 }
 
-template <typename T>
-struct fp8_e4m3_adjusted_max;
-
-template <>
-struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
-  static constexpr c10::Float8_e4m3fn val() {
-    return std::numeric_limits<c10::Float8_e4m3fn>::max();
-  }
-};
-
-// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
-// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
-template <>
-struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
-  static constexpr c10::Float8_e4m3fnuz val() {
-    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
-  }
-};
-
-template <typename T>
-MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
-    fp8_e4m3_adjusted_max<T>::val();
-
 namespace vllm {
 
 __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
@@ -76,8 +45,8 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
     x = val / scale;
   }
 
-  float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
-                 fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  float r =
+      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
 #ifndef USE_ROCM
   return static_cast<fp8_type>(r);
 #else
@@ -123,7 +92,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
+    atomicMaxFloat(scale, cache[0] / quant_type_max_v<fp8_type>);
   }
 }
 
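The two hunks above both rely on clamping into [-quant_type_max_v, quant_type_max_v] before the narrowing cast, so out-of-range values saturate instead of overflowing. A standalone sketch of that saturating convert follows; it is not part of the commit, the float stand-in type and the 448.0 maximum are assumptions, and the real kernel casts to fp8_type at the end.

// Saturating clamp-then-cast, mirroring scaled_fp8_conversion (sketch only).
#include <cmath>
#include <cstdio>

constexpr float kQuantMax = 448.0f;  // stands in for quant_type_max_v<fp8_type>

static float saturating_convert(float x) {
  // Clamp into the representable range before the narrowing cast.
  float const r = std::fmax(-kQuantMax, std::fmin(x, kQuantMax));
  return r;  // the kernel does static_cast<fp8_type>(r) here
}

int main() {
  std::printf("%g %g %g\n", saturating_convert(1.5f),
              saturating_convert(1e6f),    // saturates to +448
              saturating_convert(-1e6f));  // saturates to -448
}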
@@ -14,8 +14,7 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
     float* __restrict__ scales,           // [num_tokens]
     scalar_t const* __restrict__ input,   // [..., hidden_size]
     scalar_t const* __restrict__ weight,  // [hidden_size]
-    float const* scale_ub, float const var_epsilon,
-    float const min_scaling_factor, int32_t const hidden_size,
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
     scalar_t* __restrict__ residual = nullptr) {
   float rms = 0.0f;
   float token_scale = 0.0f;
@@ -27,8 +26,8 @@ __device__ void rms_norm_dynamic_per_token_quant_vec(
   // Compute scale
   vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
                                                      has_residual>(
-      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
-      hidden_size, residual);
+      &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
+      residual);
 
   // RMS Norm + Quant
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
@@ -50,8 +49,7 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
     float* __restrict__ scales,           // [num_tokens]
     scalar_t const* __restrict__ input,   // [..., hidden_size]
     scalar_t const* __restrict__ weight,  // [hidden_size]
-    float const* scale_ub, float const var_epsilon,
-    float const min_scaling_factor, int32_t const hidden_size,
+    float const* scale_ub, float const var_epsilon, int32_t const hidden_size,
     scalar_t* __restrict__ residual = nullptr) {
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -60,8 +58,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
   if (can_vectorize) {
     return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
                                                 has_residual>(
-        out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
-        hidden_size, residual);
+        out, scales, input, weight, scale_ub, var_epsilon, hidden_size,
+        residual);
   }
 
   float rms = 0.0f;
@@ -72,8 +70,8 @@ __global__ void rms_norm_dynamic_per_token_quant_kernel(
                                var_epsilon, residual);
   // Compute Scale
   vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
-      &token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
-      hidden_size, residual);
+      &token_scale, scales, input, weight, rms, scale_ub, hidden_size,
+      residual);
 
   // RMS Norm + Quant
   if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
@@ -105,11 +103,6 @@ void rms_norm_dynamic_per_token_quant_dispatch(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  const float min_scaling_factor =
-      out.dtype() == torch::kInt8
-          ? std::numeric_limits<float>::epsilon()
-          : 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);
-
   if (residual.has_value()) {
     VLLM_DISPATCH_QUANT_TYPES(
         out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
@@ -119,8 +112,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
                 out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                 input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                 scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-                var_epsilon, min_scaling_factor, hidden_size,
-                residual->data_ptr<scalar_in_t>());
+                var_epsilon, hidden_size, residual->data_ptr<scalar_in_t>());
           });
 
   } else {
@@ -132,7 +124,7 @@ void rms_norm_dynamic_per_token_quant_dispatch(
                 out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
                 input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
                 scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-                var_epsilon, min_scaling_factor, hidden_size, nullptr);
+                var_epsilon, hidden_size, nullptr);
           });
   }
 }
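One consequence of the dispatch hunks above: the host no longer derives the floor from c10::Float8_e4m3fn for every non-int8 output; each kernel instead calls min_scaling_factor<scalar_out_t>::val(), so the floor tracks quant_type_max_v of the actual output type (224.0 for the adjusted e4m3fnuz on ROCm). A small worked comparison follows, not part of the commit; 448.0 is the assumed e4m3fn maximum.

// Per-output-type minimum scaling factors after the refactor (sketch only).
#include <cstdio>
#include <limits>

int main() {
  float const e4m3fn_max = 448.0f;    // assumed std::numeric_limits max
  float const e4m3fnuz_max = 224.0f;  // adjusted ROCm max (0x7E) from utils.cuh

  std::printf("int8 floor     : %g\n", std::numeric_limits<float>::epsilon());
  std::printf("e4m3fn floor   : %g\n", 1.0f / (e4m3fn_max * 512.0f));
  std::printf("e4m3fnuz floor : %g\n", 1.0f / (e4m3fnuz_max * 512.0f));
}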
@@ -5,6 +5,7 @@
  */
 
 #include "quantization/vectorization.cuh"
+#include "quantization/utils.cuh"
 #include "quant_conversions.cuh"
 
 #ifndef USE_ROCM
@@ -51,11 +52,11 @@ __device__ void compute_dynamic_per_token_scales(
     float* __restrict__ token_scale, float* __restrict__ all_token_scales,
     scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
     float const rms, float const* __restrict__ scale_ub,
-    float const min_scaling_factor, int32_t const hidden_size,
+    int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   ;
-  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
+  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
   float block_absmax_val_maybe = 0.0f;
   for (auto i = threadIdx.x; i < hidden_size; i += blockDim.x) {
@@ -83,7 +84,7 @@ __device__ void compute_dynamic_per_token_scales(
     scale = block_absmax_val_maybe;
   }
   // token scale computation
-  scale = max(scale / qmax, min_scaling_factor);
+  scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
   s_token_scale = scale;                 // Shared memory store
   all_token_scales[blockIdx.x] = scale;  // Global output store
 }
@@ -184,7 +185,7 @@ __device__ void compute_dynamic_per_token_scales(
     float* __restrict__ token_scale, float* __restrict__ all_token_scales,
     scalar_t const* __restrict__ input, scalar_t const* __restrict__ weight,
     float const rms, float const* __restrict__ scale_ub,
-    float const min_scaling_factor, int32_t const hidden_size,
+    int32_t const hidden_size,
     scalar_t const* __restrict__ residual = nullptr) {
   int64_t const token_offset = blockIdx.x * static_cast<int64_t>(hidden_size);
   ;
@@ -200,7 +201,7 @@ __device__ void compute_dynamic_per_token_scales(
         reinterpret_cast<vec4_t<scalar_t> const*>(&residual[token_offset]);
   }
 
-  constexpr scalar_out_t qmax{std::numeric_limits<scalar_out_t>::max()};
+  constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
   int32_t const num_vec_elems = hidden_size >> 2;
   float block_absmax_val_maybe = 0.0f;
@@ -248,7 +249,7 @@ __device__ void compute_dynamic_per_token_scales(
     scale = block_absmax_val_maybe;
   }
   // token scale computation
-  scale = max(scale / qmax, min_scaling_factor);
+  scale = max(scale / qmax, min_scaling_factor<scalar_out_t>::val());
   s_token_scale = scale;                 // shared memory store
   all_token_scales[blockIdx.x] = scale;  // global output store
 }
@@ -33,8 +33,8 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
 
 template <typename fp8_type>
 static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
-  float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
-                       fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  float const r =
+      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
   return static_cast<fp8_type>(r);
 }
 
csrc/quantization/utils.cuh (new file, 59 lines)
@@ -0,0 +1,59 @@
+#pragma once
+
+/**
+ * Quantization utilities including:
+ *   Adjusted maximum values for qtypes.
+ *   Minimum scaling factors for qtypes.
+ */
+
+#include <cmath>
+#include <torch/types.h>
+
+#ifndef USE_ROCM
+  #include <c10/util/Float8_e4m3fn.h>
+  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
+#else
+  #include <ATen/hip/HIPContext.h>
+  #include <c10/util/Float8_e4m3fn.h>
+  #include <c10/util/Float8_e4m3fnuz.h>
+  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
+  #define MAYBE_HOST_DEVICE
+#endif
+
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
+                                      std::is_same_v<T, int8_t>>>
+struct quant_type_max {
+  static constexpr T val() { return std::numeric_limits<T>::max(); }
+};
+
+// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
+// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
+template <>
+struct quant_type_max<c10::Float8_e4m3fnuz> {
+  static constexpr c10::Float8_e4m3fnuz val() {
+    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
+  }
+};
+
+template <typename T>
+MAYBE_HOST_DEVICE static constexpr T quant_type_max_v =
+    quant_type_max<T>::val();
+
+template <typename T,
+          typename = std::enable_if_t<std::is_same_v<T, c10::Float8_e4m3fn> ||
+                                      std::is_same_v<T, c10::Float8_e4m3fnuz> ||
+                                      std::is_same_v<T, int8_t>>>
+struct min_scaling_factor {
+  C10_DEVICE C10_ALWAYS_INLINE static float val() {
+    return 1.0f / (quant_type_max_v<T> * 512.0f);
+  }
+};
+
+template <>
+struct min_scaling_factor<int8_t> {
+  C10_DEVICE C10_ALWAYS_INLINE static float val() {
+    return std::numeric_limits<float>::epsilon();
+  }
+};
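To show how the new header is meant to be consumed, here is a self-contained sketch of the same trait pattern using a placeholder type instead of the c10 fp8 types, so it builds without PyTorch headers. It is not part of the commit; FakeFp8Fnuz and the float-returning val() are simplifications of the real c10::Float8_e4m3fnuz specialization.

// Simplified re-implementation of the quant_type_max / min_scaling_factor
// pattern from csrc/quantization/utils.cuh (illustration only).
#include <cstdint>
#include <cstdio>
#include <limits>

struct FakeFp8Fnuz {};  // placeholder for c10::Float8_e4m3fnuz

template <typename T>
struct quant_type_max {
  static constexpr float val() { return std::numeric_limits<T>::max(); }
};

// Mirrors the adjusted 224.0 maximum the header uses for e4m3fnuz on ROCm.
template <>
struct quant_type_max<FakeFp8Fnuz> {
  static constexpr float val() { return 224.0f; }
};

template <typename T>
constexpr float quant_type_max_v = quant_type_max<T>::val();

template <typename T>
struct min_scaling_factor {
  static constexpr float val() { return 1.0f / (quant_type_max_v<T> * 512.0f); }
};

// int8 keeps the float-epsilon floor that the old host-side code used.
template <>
struct min_scaling_factor<std::int8_t> {
  static constexpr float val() { return std::numeric_limits<float>::epsilon(); }
};

int main() {
  std::printf("int8: qmax=%g floor=%g\n", quant_type_max_v<std::int8_t>,
              min_scaling_factor<std::int8_t>::val());
  std::printf("fnuz: qmax=%g floor=%g\n", quant_type_max_v<FakeFp8Fnuz>,
              min_scaling_factor<FakeFp8Fnuz>::val());
}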
@@ -2,7 +2,6 @@
 
 import pytest
 import torch
-from compressed_tensors.quantization import FP8_DTYPE
 
 import vllm.envs as envs
 import vllm.plugins
@@ -14,9 +13,12 @@ from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
+from vllm.platforms import current_platform
 
 from .backend import TestBackend
 
+FP8_DTYPE = current_platform.fp8_dtype()
+
 
 class TestModel(torch.nn.Module):
 
@@ -59,8 +61,8 @@ class TestModel(torch.nn.Module):
 @pytest.mark.parametrize("static", [True, False])
 @pytest.mark.parametrize("cutlass_fp8_enabled",
                          [True, False] if CUTLASS_FP8_SUPPORTED else [False])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda",
-                    reason="Only test on CUDA")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
+                    reason="Only test on CUDA and ROCm")
 def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
                               cutlass_fp8_enabled):
     torch.set_default_device("cuda")
@@ -4,8 +4,6 @@ from typing import Callable, Dict, List, NamedTuple, Optional, Tuple
 
 import torch
 import torch._inductor.pattern_matcher as pm
-# TODO(luka) use vllm.utils once #10836 landed
-from compressed_tensors.quantization import FP8_DTYPE
 from torch import fx
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import PatternMatcherPass
@@ -13,12 +11,14 @@ from torch._ops import OpOverload
 
 from vllm.config import CompilationConfig
 from vllm.logger import init_logger
+from vllm.platforms import current_platform
 
 from .fx_utils import find_getitem_maybe
 from .multi_output_match import MultiOutputMatch
 from .vllm_inductor_pass import VllmInductorPass
 
 logger = init_logger(__name__)
+FP8_DTYPE = current_platform.fp8_dtype()
 
 
 def empty_bf16(*args, **kwargs):