dynamic dispatch of fp8 kernels (#14245)
Signed-off-by: Jeff Daily <jeff.daily@amd.com>
parent 08a1a1121d
commit a1c8f3796c
@@ -18,8 +18,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser
 
-FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
-) else torch.float8_e4m3fn
+FP8_DTYPE = current_platform.fp8_dtype()
 
 
 class BenchmarkConfig(TypedDict):

@@ -6,6 +6,11 @@
 
 #include <torch/all.h>
 
+// Need a special dispatch case macro since we will nest the FP8 dispatch.
+// Instead of the usual 'scalar_t', this names the dispatched type 'fp8_t'.
+#define AT_DISPATCH_FP8_CASE(enum_type, ...) \
+  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, fp8_t, __VA_ARGS__)
+
 #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)          \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
@@ -14,17 +19,32 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
 
-// TODO(luka/varun): use FP8_TYPE macro after refactoring
-#ifndef USE_ROCM
-  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                    \
-    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
-    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
-#else
-  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
-    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
-    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
-#endif
+// ROCm devices might use either fn or fnuz, so set up dispatch table for both.
+// A host-based check at runtime will create a preferred FP8 type for ROCm
+// such that the correct kernel is dispatched.
+#ifdef USE_ROCM
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...)                            \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__)
+
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#else
+  #define VLLM_DISPATCH_CASE_FP8_TYPES(...) \
+    AT_DISPATCH_FP8_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                    \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#endif
 
+// When using this dispatch macro, the type is 'fp8_t' not 'scalar_t'.
+// See AT_DISPATCH_FP8_CASE above.
+#define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))
+
 #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
 

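The two macro families above are designed to nest: the outer VLLM_DISPATCH_FLOATING_TYPES binds the activation dtype to `scalar_t`, while the inner VLLM_DISPATCH_FP8_TYPES binds the output dtype to `fp8_t`, so one host wrapper can launch a kernel templated on both. A minimal sketch of that pattern, assuming the macros above are in scope (the kernel and wrapper names here are illustrative, not part of this commit):

    #include <torch/all.h>
    #include <ATen/cuda/CUDAContext.h>

    // Hypothetical kernel, shown only to illustrate the two template parameters.
    template <typename scalar_t, typename fp8_t>
    __global__ void example_quant_kernel(fp8_t* out, const scalar_t* in, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = fp8_t(static_cast<float>(in[i]));
    }

    void launch_example(torch::Tensor& out, torch::Tensor& input) {
      int n = static_cast<int>(input.numel());
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      // Outer macro binds 'scalar_t' from the input dtype; inner macro binds
      // 'fp8_t' from the output dtype (e4m3fn, or e4m3fnuz on MI300-class ROCm).
      VLLM_DISPATCH_FLOATING_TYPES(
          input.scalar_type(), "example_scalar_type", [&] {
            VLLM_DISPATCH_FP8_TYPES(out.scalar_type(), "example_fp8_type", [&] {
              example_quant_kernel<scalar_t, fp8_t>
                  <<<(n + 255) / 256, 256, 0, stream>>>(
                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), n);
            });
          });
    }
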
@@ -21,9 +21,9 @@
 namespace vllm {
 
 // TODO(woosuk): Further optimize this kernel.
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     const scalar_t* __restrict__ input,   // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -52,7 +52,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
     float x = (float)input[blockIdx.x * hidden_size + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true>(out_norm, scale_inv);
+        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
   }
 }
 
@@ -60,10 +60,10 @@ __global__ void rms_norm_static_fp8_quant_kernel(
    Additional optimizations we can make in this case are
    packed and vectorized operations, which help with the
    memory latency bottleneck. */
-template <typename scalar_t, int width>
+template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
@@ -114,7 +114,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 #pragma unroll
     for (int i = 0; i < width; ++i) {
       out[id * width + i] =
-          scaled_fp8_conversion<true>(float(temp.data[i]), scale_inv);
+          scaled_fp8_conversion<true, fp8_type>(float(temp.data[i]), scale_inv);
     }
   }
 }
@@ -122,10 +122,10 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 /* Generic fused_add_rms_norm_kernel
    The width field is not used here but necessary for other specializations.
  */
-template <typename scalar_t, int width>
+template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out,           // [..., hidden_size]
+    fp8_type* __restrict__ out,           // [..., hidden_size]
     scalar_t* __restrict__ input,         // [..., hidden_size]
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
@@ -158,7 +158,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
     float x = (float)residual[blockIdx.x * hidden_size + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
-        scaled_fp8_conversion<true>(out_norm, scale_inv);
+        scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
   }
 }
 
@@ -176,25 +176,33 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_static_fp8_quant_kernel<scalar_t>
-        <<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), epsilon,
-            num_tokens, hidden_size);
-  });
+  VLLM_DISPATCH_FLOATING_TYPES(
+      input.scalar_type(), "rms_norm_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "rms_norm_kernel_fp8_type", [&] {
+              vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
+                      epsilon, num_tokens, hidden_size);
+            });
+      });
 }
 
 #define LAUNCH_FUSED_ADD_RMS_NORM(width)                                      \
   VLLM_DISPATCH_FLOATING_TYPES(                                               \
-      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                 \
-        vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t, width>     \
-            <<<grid, block, 0, stream>>>(                                     \
-                out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),         \
-                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),   \
-                scale.data_ptr<float>(), epsilon, num_tokens, hidden_size);   \
+      input.scalar_type(), "fused_add_rms_norm_kernel_scalar_type", [&] {     \
+        VLLM_DISPATCH_FP8_TYPES(                                              \
+            out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] {    \
+              vllm::fused_add_rms_norm_static_fp8_quant_kernel<scalar_t,      \
+                                                               width, fp8_t>  \
+                  <<<grid, block, 0, stream>>>(                               \
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),      \
+                      residual.data_ptr<scalar_t>(),                          \
+                      weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),   \
+                      epsilon, num_tokens, hidden_size);                      \
+            });                                                               \
       });
 
 void fused_add_rms_norm_static_fp8_quant(
     torch::Tensor& out,     // [..., hidden_size],
     torch::Tensor& input,   // [..., hidden_size]

@@ -13,6 +13,28 @@ namespace vllm {
 namespace fp8 {
 #ifdef ENABLE_FP8
 
+// Use hardware cvt instruction for fp8 on rocm
+template <typename fp8_type>
+__device__ __forceinline__ fp8_type cvt_c10(float const r) {
+  return {};
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fn cvt_c10(float const r) {
+  return c10::Float8_e4m3fn(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3::__default_saturation,
+                             __hip_fp8_e4m3::__default_interpret),
+      c10::Float8_e4m3fn::from_bits());
+}
+
+template <>
+__device__ __forceinline__ c10::Float8_e4m3fnuz cvt_c10(float const r) {
+  return c10::Float8_e4m3fnuz(
+      __hip_cvt_float_to_fp8(r, __hip_fp8_e4m3_fnuz::__default_saturation,
+                             __hip_fp8_e4m3_fnuz::__default_interpret),
+      c10::Float8_e4m3fnuz::from_bits());
+}
+
 template <typename Tout, typename Tin>
 __inline__ __device__ Tout vec_conversion(const Tin& x) {
   return x;

@@ -11,8 +11,8 @@
 
 namespace vllm {
 
-template <typename scalar_t>
-__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, typename fp8_type>
+__global__ void scaled_fp8_quant_kernel(fp8_type* __restrict__ out,
                                         const scalar_t* __restrict__ input,
                                         const float* __restrict__ scale,
                                         int64_t num_elems) {
@@ -25,12 +25,13 @@ __global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
       out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
-    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
+    fp8_type* __restrict__ out, float* __restrict__ scale,
     scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
     const int hidden_size) {
-  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
+  float const min_scaling_factor =
+      1.0f / (fp8_e4m3_adjusted_max_v<fp8_type> * 512.f);
 
   int const tid = threadIdx.x;
   int const token_idx = blockIdx.x;
@@ -38,7 +39,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   // Use int64 to avoid overflowing an int32 when calculating this offset
   int64_t offset = static_cast<int64_t>(token_idx) * hidden_size;
   scalar_t const* __restrict__ token_input = &input[offset];
-  FP8_TYPE* __restrict__ token_output = &out[offset];
+  fp8_type* __restrict__ token_output = &out[offset];
 
   // For vectorization, token_input and token_output pointers need to be
   // aligned at 8-byte and 4-byte addresses respectively.
@@ -66,7 +67,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
+    token_scale = max(token_scale / fp8_e4m3_adjusted_max_v<fp8_type>,
+                      min_scaling_factor);
     scale[token_idx] = token_scale;
   }
   __syncthreads();
@@ -77,7 +79,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
         token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
-      token_output[i] = scaled_fp8_conversion<false>(
+      token_output[i] = scaled_fp8_conversion<false, fp8_type>(
           static_cast<float>(token_input[i]), token_scale);
     }
   }
@@ -96,10 +98,14 @@ void static_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
@@ -114,12 +120,18 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
-        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
-            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
-        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
-            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
-            scale.data_ptr<float>(), num_elems);
+      input.scalar_type(), "scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(), "scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::segmented_max_reduction<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(scale.data_ptr<float>(),
                                                input.data_ptr<scalar_t>(),
                                                num_elems);
+              vllm::scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
+                      scale.data_ptr<float>(), num_elems);
+            });
       });
 }
 
@@ -138,12 +150,18 @@ void dynamic_per_token_scaled_fp8_quant(
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(
-      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
-        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
-            <<<grid, block, 0, stream>>>(
-                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
-                input.data_ptr<scalar_t>(),
-                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
-                hidden_size);
+      input.scalar_type(),
+      "dynamic_per_token_scaled_fp8_quant_kernel_scalar_type", [&] {
+        VLLM_DISPATCH_FP8_TYPES(
+            out.scalar_type(),
+            "dynamic_per_token_scaled_fp8_quant_kernel_fp8_type", [&] {
+              vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t, fp8_t>
+                  <<<grid, block, 0, stream>>>(
+                      out.data_ptr<fp8_t>(), scales.data_ptr<float>(),
+                      input.data_ptr<scalar_t>(),
+                      scale_ub.has_value() ? scale_ub->data_ptr<float>()
                                            : nullptr,
+                      hidden_size);
+            });
       });
 }

@@ -7,18 +7,52 @@
 
 #ifndef USE_ROCM
   #include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
-    std::numeric_limits<FP8_TYPE>::max();
+  #define MAYBE_HOST_DEVICE C10_HOST_DEVICE
 #else
+  #include <ATen/hip/HIPContext.h>
+  #include <c10/util/Float8_e4m3fn.h>
   #include <c10/util/Float8_e4m3fnuz.h>
   #include "amd/quant_utils.cuh"
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-// Using the default max value from pytorch (240.0) will cause accuracy
-// issue when running dynamic quantization. Here use 224.0f for rocm.
-constexpr auto FP8_E4M3_MAX = 224.0f;
+  // ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
+  #define MAYBE_HOST_DEVICE
 #endif
-constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;
+
+// Determines the preferred FP8 type for the current platform.
+// Note that for CUDA this just returns true,
+// but on ROCm it will check device props.
+static bool is_fp8_ocp() {
+#ifndef USE_ROCM
+  return true;
+#else
+  auto dprops = at::cuda::getCurrentDeviceProperties();
+  std::string device_arch = dprops->gcnArchName;
+  size_t substring = device_arch.find("gfx94");
+  return substring == std::string::npos;
+#endif
+}
+
+template <typename T>
+struct fp8_e4m3_adjusted_max;
+
+template <>
+struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fn> {
+  static constexpr c10::Float8_e4m3fn val() {
+    return std::numeric_limits<c10::Float8_e4m3fn>::max();
+  }
+};
+
+// Using the default max value from pytorch (240.0 0x7F) will cause accuracy
+// issues when running dynamic quantization. Here use 224.0 0x7E for rocm.
+template <>
+struct fp8_e4m3_adjusted_max<c10::Float8_e4m3fnuz> {
+  static constexpr c10::Float8_e4m3fnuz val() {
+    return c10::Float8_e4m3fnuz(0x7E, c10::Float8_e4m3fnuz::from_bits());
+  }
+};
+
+template <typename T>
+MAYBE_HOST_DEVICE static constexpr T fp8_e4m3_adjusted_max_v =
+    fp8_e4m3_adjusted_max<T>::val();
 
 namespace vllm {
 
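In short, the single compile-time constant FP8_E4M3_MAX is replaced by the per-type trait fp8_e4m3_adjusted_max_v<T> (448 for e4m3fn, and 224 rather than 240 for e4m3fnuz, for the accuracy reason noted above), and the compile-time dtype choice is replaced by the runtime helper is_fp8_ocp(). A minimal sketch of how the trait is consumed, condensed from the conversion helpers later in this diff (the function name is illustrative, not part of this commit):

    // Clamp to the per-type maximum before converting, so each FP8 flavor
    // saturates at its own bound instead of a single hard-coded constant.
    template <typename fp8_type>
    __device__ fp8_type quantize_one(float val, float scale) {
      float x = val / scale;
      float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
                     fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
      return static_cast<fp8_type>(r);
    }
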
@@ -32,8 +66,8 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
   return old;
 }
 
-template <bool is_scale_inverted>
-__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
+template <bool is_scale_inverted, typename fp8_type>
+__device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
                                                            float const scale) {
   float x = 0.0f;
   if constexpr (is_scale_inverted) {
@@ -42,15 +76,13 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
     x = val / scale;
   }
 
-  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
+  float r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
+                 fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
 #ifndef USE_ROCM
-  return static_cast<c10::Float8_e4m3fn>(r);
+  return static_cast<fp8_type>(r);
 #else
   // Use hardware cvt instruction for fp8 on rocm
-  return c10::Float8_e4m3fnuz(
-      __hip_cvt_float_to_fp8(r, fp8::fp8_type::__default_saturation,
-                             fp8::fp8_type::__default_interpret),
-      c10::Float8_e4m3fnuz::from_bits());
+  return fp8::cvt_c10<fp8_type>(r);
 #endif
 }
 
@@ -60,7 +92,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
 // So to get the right answer, *scale needs to be initialized to
 // a value <= 0.0 and we need to wait for all thread blocks to
 // finish before consuming *scale.
-template <typename scalar_t>
+template <typename scalar_t, typename fp8_type>
 __global__ void segmented_max_reduction(float* __restrict__ scale,
                                         const scalar_t* __restrict__ input,
                                         int64_t num_elems) {
@@ -91,7 +123,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   // Finally, since cache[0] contains the maximum for this thread block,
   // atomically write the max to the target location
   if (threadIdx.x == 0) {
-    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
+    atomicMaxFloat(scale, cache[0] / fp8_e4m3_adjusted_max_v<fp8_type>);
   }
 }
 
@@ -123,13 +155,13 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
   return absmax_val;
 }
 
-template <typename scalar_t, bool is_scale_inverted>
-__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
+template <typename scalar_t, bool is_scale_inverted, typename fp8_type>
+__device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                           scalar_t const* __restrict__ input,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
-  using float8x4_t = q8x4_t<FP8_TYPE>;
+  using float8x4_t = q8x4_t<fp8_type>;
   // Vectorized input/output to better utilize memory bandwidth.
   auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
   auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
@@ -141,22 +173,22 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
     vec4_t<scalar_t> in_vec = vectorized_in[i];
     float8x4_t out_vec;
 
-    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
         static_cast<float>(in_vec.x), scale);
-    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.y), scale);
-    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.z), scale);
-    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
+    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
        static_cast<float>(in_vec.w), scale);
     vectorized_out[i] = out_vec;
   }
 
   // Handle the remaining elements if num_elems is not divisible by 4
   for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
-    out[i] = scaled_fp8_conversion<is_scale_inverted>(
+    out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
         static_cast<float>(input[i]), scale);
   }
 }
 
 }  // namespace vllm

@@ -144,6 +144,9 @@ void rms_norm_dynamic_per_token_quant(
     torch::Tensor& scales,      // [num_tokens]
     double const var_epsilon,   // Variance epsilon used in norm calculation
     std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
+  static c10::ScalarType kFp8Type = is_fp8_ocp()
+                                        ? c10::ScalarType::Float8_e4m3fn
+                                        : c10::ScalarType::Float8_e4m3fnuz;
   TORCH_CHECK(out.dtype() == kFp8Type || out.dtype() == torch::kInt8);
   TORCH_CHECK(out.is_contiguous() && input.is_contiguous());
 
@@ -31,9 +31,11 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) {
 #endif
 }
 
-static __device__ __forceinline__ FP8_TYPE float_to_fp8(float const x) {
-  float const r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
-  return static_cast<FP8_TYPE>(r);
+template <typename fp8_type>
+static __device__ __forceinline__ fp8_type float_to_fp8(float const x) {
+  float const r = fmax(-fp8_e4m3_adjusted_max_v<fp8_type>,
+                       fmin(x, fp8_e4m3_adjusted_max_v<fp8_type>));
+  return static_cast<fp8_type>(r);
 }
 
 template <typename quant_type_t, bool is_scale_inverted, typename enable = void>
@@ -54,15 +56,16 @@ struct ScaledQuant<
 };
 
 template <typename quant_type_t, bool is_scale_inverted>
-struct ScaledQuant<
-    quant_type_t, is_scale_inverted,
-    typename std::enable_if_t<std::is_same_v<quant_type_t, FP8_TYPE>>> {
+struct ScaledQuant<quant_type_t, is_scale_inverted,
+                   typename std::enable_if_t<
+                       std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
+                       std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>>> {
   static __device__ __forceinline__ quant_type_t quant_fn(float const x,
                                                           float const scale) {
     if constexpr (is_scale_inverted) {
-      return float_to_fp8(x * scale);
+      return float_to_fp8<quant_type_t>(x * scale);
     } else {
-      return float_to_fp8(x / scale);
+      return float_to_fp8<quant_type_t>(x / scale);
     }
   }
 };

@@ -4,7 +4,6 @@
  */
 
 // Include both AMD and NVIDIA fp8 types to avoid circular import
-// TODO(luka/varun) use FP8_TYPE instead after refactoring
 #include <c10/util/Float8_e4m3fnuz.h>
 #include <c10/util/Float8_e4m3fn.h>
 

@@ -9,8 +9,7 @@ from vllm.platforms import current_platform
 # Using the default value (240.0) from pytorch will cause accuracy
 # issue on dynamic quantization models. Here use 224.0 for rocm.
 ROCM_FP8_MAX = 224.0
-FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm() \
-    else torch.float8_e4m3fn
+FP8_DTYPE = current_platform.fp8_dtype()
 
 
 def as_float32_tensor(x: Union[float, torch.tensor]) -> torch.tensor:

@@ -32,11 +32,8 @@ def scaled_mm_torch(a: torch.Tensor,
 
 def get_8bit_types():
     types = [torch.int8]
-    supports_fp8 = current_platform.has_device_capability(89)
-    if current_platform.is_rocm() and supports_fp8:
-        types.append(torch.float8_e4m3fnuz)
-    elif current_platform.is_cuda() and supports_fp8:
-        types.append(torch.float8_e4m3fn)
+    if current_platform.supports_fp8():
+        types.append(current_platform.fp8_dtype())
     return types
 
 

@@ -103,8 +103,7 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
         assert attn._v_scale == 1.0
 
         if current_platform.is_cuda():
-            if current_platform.has_device_capability(
-                    89) and not force_marlin:
+            if current_platform.supports_fp8() and not force_marlin:
                 # For GPUs with hardware support, we keep weights in fp8
                 assert fc1.weight.dtype == torch.float8_e4m3fn
             else:
@@ -112,11 +111,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
                 # for weight-only quantization using Marlin kernels
                 assert fc1.weight.dtype == torch.int32
         elif current_platform.is_rocm():
-            # Only MI300 and above support quantization='fp8'
-            if current_platform.has_device_capability(
-                    94) and not force_marlin:
+            if current_platform.supports_fp8() and not force_marlin:
                 # For GPUs with hardware support, we keep weights in fp8
-                assert fc1.weight.dtype == torch.float8_e4m3fnuz
+                assert fc1.weight.dtype == current_platform.fp8_dtype()
             else:  # unsupported ROCm platform
                 pytest.skip(
                     "Skip `test_load_fp16_model`. "

@@ -478,16 +478,16 @@ def cutlass_scaled_mm(a: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     """
    `cutlass_scaled_mm` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.
 
    In order to support blockwise scaling like found in DeepSeek V3 we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and
        corresponding extent in the target shape we repeat each element along
        that dimension src_shape[dim] // target_shape[dim] times consecutively"
    example if we have:
        a = [[1, 2], and target_shape = (2, 4)
@@ -564,7 +564,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
    with Cutlass sparse kernels.
 
    Args:
        a (torch.Tensor):
            The input tensor to be compressed. Must have one of the following data types:
            - `torch.int8`
            - `torch.float8_e4m3fn`
@@ -572,7 +572,7 @@ def cutlass_sparse_compress(a: torch.Tensor) \
            - `torch.float16`
 
    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            A tuple containing:
            - `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
            - `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
@@ -875,9 +875,8 @@ def scaled_fp8_quant(
     # This code assumes batch_dim and num_tokens are flattened
     assert (input.ndim == 2)
     shape: Union[tuple[int, int], torch.Size] = input.shape
-    # For rocm, the output fp8 dtype is torch.float_e3m3fnuz
-    out_dtype: torch.dtype = torch.float8_e4m3fnuz \
-        if current_platform.is_rocm() else torch.float8_e4m3fn
+    # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz
+    out_dtype: torch.dtype = current_platform.fp8_dtype()
     if num_token_padding:
         shape = (max(num_token_padding, input.shape[0]), shape[1])
     output = torch.empty(shape, device=input.device, dtype=out_dtype)
@@ -908,7 +907,7 @@ def allspark_repack_weight(
        has_zp: bool = False
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
    Rearrange qweight, scale, and zero_point(if asymmetric) to n32k16 format
    for Ampere W8A16 Fused Gemm kernel
 
    Args:
@@ -917,10 +916,10 @@ def allspark_repack_weight(
        zero_point: fp16/bf16 weight zero_point tensor, 1 x n format.
                    Must be provided for asymmetric quantization.
        has_zp: if use symmetric quantization, has_zp = False.
                if use asymmetric quantization, has_zp = True.
 
    Returns:
        tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] :
            rearranged weight, scale, and optionally zero_point.
    """
     K = qweight.shape[0]

@@ -226,7 +226,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import (
@@ -1238,7 +1238,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             W_Q_UK, W_Q_UK_scales = scaled_quantize(
                 W_Q_UK,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_Q_UK = W_Q_UK.T.contiguous()
@@ -1255,7 +1255,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
             W_UV_O, W_UV_O_scales = scaled_quantize(
                 W_UV_O,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_UV_O = W_UV_O.T.contiguous()

@@ -158,8 +158,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
 
-        # If rocm, normalize the weights and scales to e4m3fnuz
-        if current_platform.is_rocm():
+        if current_platform.is_fp8_fnuz():
             # Normalize the weights and scales
             w13_weight, w13_weight_scale, w13_input_scale = \
                 normalize_e4m3fn_to_e4m3fnuz(

@@ -42,7 +42,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
                 logical_widths=layer.logical_widths,
             )
 
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 input_scale = getattr(layer, 'input_scale', None)
 
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
@@ -60,7 +60,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         elif self.strategy == QuantizationStrategy.CHANNEL:
             weight = layer.weight
 
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 input_scale = getattr(layer, 'input_scale', None)
 
                 weight, weight_scale, input_scale = \

@@ -127,7 +127,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
 
         weight = layer.weight
 
-        if current_platform.is_rocm():
+        if current_platform.is_fp8_fnuz():
             weight, weight_scale, input_scale = \
                 normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,

@@ -270,7 +270,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # TODO(rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, weight_scale_inv, _ = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=layer.weight,
@@ -327,8 +327,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight = layer.weight
             weight_scale = layer.weight_scale
 
-            # If rocm, use float8_e4m3fnuz.
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,
@@ -533,7 +532,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 w13_weight, w13_weight_scale_inv, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         layer.w13_weight, layer.w13_weight_scale_inv,
@@ -559,9 +558,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
         # If checkpoint is fp16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz \
-                if current_platform.is_rocm() else torch.float8_e4m3fn
+            fp8_dtype = current_platform.fp8_dtype()
             w13_weight = torch.empty_like(layer.w13_weight.data,
                                           dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -608,8 +605,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w13_input_scale.max(), requires_grad=False)
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
-            # If rocm, normalize the weights and scales to e4m3fnuz
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(

@@ -142,8 +142,7 @@ class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
             layer.w2_input_scale = torch.nn.Parameter(
                 layer.w2_input_scale.max(), requires_grad=False)
 
-        # If rocm, normalize the weights and scales to e4m3fnuz
-        if current_platform.is_rocm():
+        if current_platform.is_fp8_fnuz():
             # Normalize the weights and scales
             w13_weight, w13_weight_scale, w13_input_scale = \
                 normalize_e4m3fn_to_e4m3fnuz(

@@ -39,7 +39,7 @@ class QuarkW8A8Fp8(QuarkScheme):
                 logical_widths=layer.logical_widths,
             )
 
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=max_w_scale,
@@ -55,7 +55,7 @@ class QuarkW8A8Fp8(QuarkScheme):
         elif self.qscheme == "per_channel":
             weight = layer.weight
 
-            if current_platform.is_rocm():
+            if current_platform.is_fp8_fnuz():
                 weight, weight_scale, input_scale = \
                     normalize_e4m3fn_to_e4m3fnuz(
                         weight=weight,

@@ -22,10 +22,6 @@ from vllm.utils import direct_register_custom_op
 
 logger = init_logger(__name__)
 
-current_platform_fp8_dtype = (torch.float8_e4m3fnuz
-                              if current_platform.is_rocm() else
-                              torch.float8_e4m3fn)
-
 
 def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
     if isinstance(x, torch.Tensor):
@@ -165,9 +161,7 @@ def input_to_float8(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values "
     "with tensor-wise quantization."""
-    if dtype is None:
-        dtype = (torch.float8_e4m3fnuz
-                 if current_platform.is_rocm() else torch.float8_e4m3fn)
+    dtype = current_platform.fp8_dtype() if dtype is None else dtype
     finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
@@ -311,9 +305,7 @@ def per_token_group_quant_fp8(
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor for quantization.
     """
-    if dtype is None:
-        dtype = (torch.float8_e4m3fnuz
-                 if current_platform.is_rocm() else torch.float8_e4m3fn)
+    dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "
         f"by `group_size` {group_size}")

@@ -293,6 +293,10 @@ class CudaPlatformBase(Platform):
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
 
+    @classmethod
+    def supports_fp8(cls) -> bool:
+        return cls.has_device_capability(89)
+
 
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

@@ -330,6 +330,36 @@ class Platform:
         """
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
 
+    @classmethod
+    def supports_fp8(cls) -> bool:
+        """
+        Returns whether the current platform supports FP8 types.
+        """
+        return False
+
+    @classmethod
+    def is_fp8_fnuz(cls) -> bool:
+        """
+        Returns whether the preferred FP8 type is FNUZ on the current platform.
+
+        There are two representations of FP8, OCP FP8 and FNUZ FP8.
+        The OCP specification can be found at https://tinyurl.com/b7jvwpft.
+        The FNUZ specification can be found at https://tinyurl.com/5n6hwwu5.
+
+        AMD's MI300 and MI325 have native hardware support for FNUZ. All other
+        hardware has converged on the OCP FP8 standard.
+        """
+        return False
+
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        """
+        Returns the preferred FP8 type on the current platform.
+
+        See the documentation for is_fp8_fnuz for details.
+        """
+        return torch.float8_e4m3fn
+
     @classmethod
     def use_all_gather(cls) -> bool:
         """
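Taken together, the new Platform hooks let callers ask the platform object three questions instead of branching on is_rocm(). A short usage sketch, assuming a CUDA or ROCm build of vLLM (the tensor below is illustrative only):

    import torch
    from vllm.platforms import current_platform

    if current_platform.supports_fp8():
        # float8_e4m3fn on CUDA and OCP-FP8 ROCm parts,
        # float8_e4m3fnuz on MI300-class (FNUZ) GPUs.
        fp8_dtype = current_platform.fp8_dtype()
        x = torch.randn(4, 16, device="cuda").to(fp8_dtype)
        needs_fnuz_handling = current_platform.is_fp8_fnuz()
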
@@ -231,3 +231,20 @@ class RocmPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator"  # noqa
+
+    @classmethod
+    def supports_fp8(cls) -> bool:
+        gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
+        return any(gfx in gcn_arch for gfx in ['gfx94', 'gfx95', 'gfx12'])
+
+    @classmethod
+    def is_fp8_fnuz(cls) -> bool:
+        # only device 0 is checked, this assumes MI300 platforms are homogeneous
+        return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
+
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.is_fp8_fnuz():
+            return torch.float8_e4m3fnuz
+        else:
+            return torch.float8_e4m3fn
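For readers unfamiliar with the gfx names: the overrides above mean gfx94x parts (MI300/MI325) report FNUZ FP8, gfx95x and gfx12xx parts report OCP FP8, and anything else reports no FP8 support. A small hypothetical helper spelling out that mapping (not part of this commit, shown only as a summary of the checks above):

    import torch

    def expected_rocm_fp8(gcn_arch: str):
        # Mirrors RocmPlatform.supports_fp8 / is_fp8_fnuz / fp8_dtype above.
        if 'gfx94' in gcn_arch:                              # MI300-class: FNUZ
            return torch.float8_e4m3fnuz
        if any(g in gcn_arch for g in ('gfx95', 'gfx12')):   # OCP FP8
            return torch.float8_e4m3fn
        return None                                          # no FP8 support
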
@@ -219,7 +219,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsW8A8Fp8)
 from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
-    Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
+    Fp8LinearGenericOp, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
@@ -826,7 +826,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             W_Q_UK, W_Q_UK_scales = scaled_quantize(
                 W_Q_UK,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_Q_UK = W_Q_UK.T.contiguous()
@@ -843,7 +843,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             W_UV_O, W_UV_O_scales = scaled_quantize(
                 W_UV_O,
                 self.reqaunt_weight_group_shape,
-                quant_dtype=current_platform_fp8_dtype)
+                quant_dtype=current_platform.fp8_dtype())
             # For FP8 save the transpose so we can use
             # `apply_w8a8_block_fp8_linear` directly
             self.W_UV_O = W_UV_O.T.contiguous()